Unverified Commit e550bdd0 by Tianqi Chen Committed by GitHub

[NODE] Macro to define NodeRef methods, constructor style example (#3224)

parent e1e91f1f
...@@ -48,11 +48,7 @@ namespace arith { ...@@ -48,11 +48,7 @@ namespace arith {
// Forward declare Analyzer // Forward declare Analyzer
class Analyzer; class Analyzer;
/*!
* \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode
*/
class ConstIntBound;
/*! /*!
* \brief Constant integer up and lower bound(inclusive). * \brief Constant integer up and lower bound(inclusive).
* Useful for value bound analysis. * Useful for value bound analysis.
...@@ -69,8 +65,6 @@ class ConstIntBoundNode : public Node { ...@@ -69,8 +65,6 @@ class ConstIntBoundNode : public Node {
v->Visit("max_value", &max_value); v->Visit("max_value", &max_value);
} }
TVM_DLL static ConstIntBound make(int64_t min_value, int64_t max_value);
/*! \brief Number to represent +inf */ /*! \brief Number to represent +inf */
static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max(); static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
/*! /*!
...@@ -83,7 +77,23 @@ class ConstIntBoundNode : public Node { ...@@ -83,7 +77,23 @@ class ConstIntBoundNode : public Node {
TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node); TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node);
}; };
TVM_DEFINE_NODE_REF(ConstIntBound, ConstIntBoundNode); /*!
* \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode
*/
class ConstIntBound : public NodeRef {
public:
/*!
* \brief constructor by fields.
* \param min_value The mininum value.
* \param max_value The maximum value.
*/
TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);
static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
TVM_DEFINE_NODE_REF_METHODS(ConstIntBound, NodeRef, ConstIntBoundNode);
};
/*! /*!
* \brief Analyzer to get constant integer bound over expression. * \brief Analyzer to get constant integer bound over expression.
...@@ -134,11 +144,6 @@ class ConstIntBoundAnalyzer { ...@@ -134,11 +144,6 @@ class ConstIntBoundAnalyzer {
}; };
/*! /*!
* \brief reference of ModularSetNode
* \sa ModularSetNode
*/
class ModularSet;
/*!
* \brief Range of a linear integer function. * \brief Range of a linear integer function.
* Use to do specify the possible index values. * Use to do specify the possible index values.
* *
...@@ -162,13 +167,20 @@ class ModularSetNode : public Node { ...@@ -162,13 +167,20 @@ class ModularSetNode : public Node {
v->Visit("base", &base); v->Visit("base", &base);
} }
TVM_DLL static ModularSet make(int64_t coeff, int64_t base);
static constexpr const char* _type_key = "arith.ModularSet"; static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node); TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node);
}; };
TVM_DEFINE_NODE_REF(ModularSet, ModularSetNode); /*!
* \brief reference of ModularSetNode
* \sa ModularSetNode
*/
class ModularSet : public NodeRef {
public:
TVM_DLL ModularSet(int64_t coeff, int64_t base);
TVM_DEFINE_NODE_REF_METHODS(ModularSet, NodeRef, ModularSetNode);
};
/*! /*!
* \brief Analyzer to get modular information over expression. * \brief Analyzer to get modular information over expression.
......
...@@ -39,21 +39,24 @@ using ::tvm::Node; ...@@ -39,21 +39,24 @@ using ::tvm::Node;
using ::tvm::NodeRef; using ::tvm::NodeRef;
using ::tvm::AttrVisitor; using ::tvm::AttrVisitor;
/*! \brief Macro to make it easy to define node ref type given node */ /*!
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ * \brief Macro to define common node ref methods.
class TypeName : public ::tvm::NodeRef { \ * \param TypeName The name of the NodeRef.
public: \ * \param BaseTypeName The Base type.
TypeName() {} \ * \param NodeName The node container type.
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} \ */
const NodeName* operator->() const { \ #define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
return static_cast<const NodeName*>(node_.get()); \ TypeName() {} \
} \ explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \
using ContainerType = NodeName; \ const NodeName* operator->() const { \
}; \ return static_cast<const NodeName*>(node_.get()); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;
/*! /*!
* \brief Macro to make it easy to define node ref type that * \brief Macro to define CopyOnWrite function in a NodeRef.
* has a CopyOnWrite member function. * \param NodeName The Type of the Node.
* *
* CopyOnWrite will generate a unique copy of the internal node. * CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places. * The node will be copied if it is referenced by multiple places.
...@@ -70,25 +73,33 @@ using ::tvm::AttrVisitor; ...@@ -70,25 +73,33 @@ using ::tvm::AttrVisitor;
* *
* \endcode * \endcode
*/ */
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ #define TVM_DEFINE_NODE_REF_COW(NodeName) \
class TypeName : public BaseType { \ NodeName* CopyOnWrite() { \
public: \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
inline NodeName* CopyOnWrite() { \
CHECK(node_ != nullptr); \ CHECK(node_ != nullptr); \
if (!node_.unique()) { \ if (!node_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \ NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
NodePtr<Node>(std::move(n)).swap(node_); \ NodePtr<Node>(std::move(n)).swap(node_); \
} \ } \
return static_cast<NodeName*>(node_.get()); \ return static_cast<NodeName*>(node_.get()); \
} \ }
using ContainerType = NodeName; \
};
/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public ::tvm::NodeRef { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
}; \
/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
TVM_DEFINE_NODE_REF_COW(NodeName); \
};
/*! /*!
* \brief save the node as well as all the node it depends on as json. * \brief save the node as well as all the node it depends on as json.
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -58,7 +58,6 @@ TVM_REGISTER_API("arith.DeduceBound") ...@@ -58,7 +58,6 @@ TVM_REGISTER_API("arith.DeduceBound")
TVM_REGISTER_API("arith.DomainTouched") TVM_REGISTER_API("arith.DomainTouched")
.set_body_typed(DomainTouched); .set_body_typed(DomainTouched);
TVM_REGISTER_API("_IntervalSetGetMin") TVM_REGISTER_API("_IntervalSetGetMin")
.set_body_method(&IntSet::min); .set_body_method(&IntSet::min);
...@@ -71,11 +70,19 @@ TVM_REGISTER_API("_IntSetIsNothing") ...@@ -71,11 +70,19 @@ TVM_REGISTER_API("_IntSetIsNothing")
TVM_REGISTER_API("_IntSetIsEverything") TVM_REGISTER_API("_IntSetIsEverything")
.set_body_method(&IntSet::is_everything); .set_body_method(&IntSet::is_everything);
ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
return ConstIntBound(min_value, max_value);
}
TVM_REGISTER_API("arith._make_ConstIntBound") TVM_REGISTER_API("arith._make_ConstIntBound")
.set_body_typed(ConstIntBoundNode::make); .set_body_typed(MakeConstIntBound);
ModularSet MakeModularSet(int64_t coeff, int64_t base) {
return ModularSet(coeff, base);
}
TVM_REGISTER_API("arith._make_ModularSet") TVM_REGISTER_API("arith._make_ModularSet")
.set_body_typed(ModularSetNode::make); .set_body_typed(MakeModularSet);
TVM_REGISTER_API("arith._CreateAnalyzer") TVM_REGISTER_API("arith._CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
......
...@@ -34,12 +34,12 @@ using namespace ir; ...@@ -34,12 +34,12 @@ using namespace ir;
TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); TVM_REGISTER_NODE_TYPE(ConstIntBoundNode);
ConstIntBound ConstIntBoundNode::make( ConstIntBound::ConstIntBound(
int64_t min_value, int64_t max_value) { int64_t min_value, int64_t max_value) {
auto node = make_node<ConstIntBoundNode>(); auto node = make_node<ConstIntBoundNode>();
node->min_value = min_value; node->min_value = min_value;
node->max_value = max_value; node->max_value = max_value;
return ConstIntBound(node); node_ = std::move(node);
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
...@@ -289,8 +289,8 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -289,8 +289,8 @@ class ConstIntBoundAnalyzer::Impl :
std::vector<BoundInfo> additional_info_; std::vector<BoundInfo> additional_info_;
// constants: the limit value means umlimited // constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity. // NOTE: kNegInf/kPosInf are used to represent infinity.
static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf; static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; static const constexpr int64_t kPosInf = ConstIntBound::kPosInf;
static_assert(-kNegInf == kPosInf, "invariant of inf"); static_assert(-kNegInf == kPosInf, "invariant of inf");
// internal helper functions // internal helper functions
/*! /*!
...@@ -462,7 +462,7 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -462,7 +462,7 @@ class ConstIntBoundAnalyzer::Impl :
ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) { ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) {
Entry ret = impl_->VisitExpr(expr); Entry ret = impl_->VisitExpr(expr);
return ConstIntBoundNode::make(ret.min_value, ret.max_value); return ConstIntBound(ret.min_value, ret.max_value);
} }
void ConstIntBoundAnalyzer::Update(const Var& var, void ConstIntBoundAnalyzer::Update(const Var& var,
......
...@@ -35,11 +35,12 @@ using namespace ir; ...@@ -35,11 +35,12 @@ using namespace ir;
TVM_REGISTER_NODE_TYPE(ModularSetNode); TVM_REGISTER_NODE_TYPE(ModularSetNode);
ModularSet ModularSetNode::make(int64_t coeff, int64_t base) { ModularSet::ModularSet(int64_t coeff, int64_t base) {
auto node = make_node<ModularSetNode>(); auto node = make_node<ModularSetNode>();
node->coeff = coeff; node->coeff = coeff;
node->base = base; node->base = base;
return ModularSet(node); // finish construction.
node_ = std::move(node);
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
...@@ -366,13 +367,13 @@ class ModularSetAnalyzer::Impl : ...@@ -366,13 +367,13 @@ class ModularSetAnalyzer::Impl :
* \return Bound that represent everything dtype can represent. * \return Bound that represent everything dtype can represent.
*/ */
static Entry Nothing() { static Entry Nothing() {
return Entry(0, 1); return Entry(0, 1);
} }
}; };
ModularSet ModularSetAnalyzer::operator()(const Expr& expr) { ModularSet ModularSetAnalyzer::operator()(const Expr& expr) {
Entry ret = impl_->VisitExpr(expr); Entry ret = impl_->VisitExpr(expr);
return ModularSetNode::make(ret.coeff, ret.base); return ModularSet(ret.coeff, ret.base);
} }
void ModularSetAnalyzer::Update(const Var& var, void ModularSetAnalyzer::Update(const Var& var,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment