Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e550bdd0
Unverified
Commit
e550bdd0
authored
May 23, 2019
by
Tianqi Chen
Committed by
GitHub
May 23, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NODE] Macro to define NodeRef methods, constructor style example (#3224)
parent
e1e91f1f
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
79 additions
and
48 deletions
+79
-48
include/tvm/arithmetic.h
+28
-16
include/tvm/base.h
+32
-21
src/api/api_arith.cc
+10
-3
src/arithmetic/const_int_bound.cc
+5
-5
src/arithmetic/modular_set.cc
+4
-3
No files found.
include/tvm/arithmetic.h
View file @
e550bdd0
...
...
@@ -48,11 +48,7 @@ namespace arith {
// Forward declare Analyzer
class
Analyzer
;
/*!
* \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode
*/
class
ConstIntBound
;
/*!
* \brief Constant integer up and lower bound(inclusive).
* Useful for value bound analysis.
...
...
@@ -69,8 +65,6 @@ class ConstIntBoundNode : public Node {
v
->
Visit
(
"max_value"
,
&
max_value
);
}
TVM_DLL
static
ConstIntBound
make
(
int64_t
min_value
,
int64_t
max_value
);
/*! \brief Number to represent +inf */
static
const
constexpr
int64_t
kPosInf
=
std
::
numeric_limits
<
int64_t
>::
max
();
/*!
...
...
@@ -83,7 +77,23 @@ class ConstIntBoundNode : public Node {
TVM_DECLARE_NODE_TYPE_INFO
(
ConstIntBoundNode
,
Node
);
};
TVM_DEFINE_NODE_REF
(
ConstIntBound
,
ConstIntBoundNode
);
/*!
* \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode
*/
class
ConstIntBound
:
public
NodeRef
{
public
:
/*!
* \brief constructor by fields.
* \param min_value The mininum value.
* \param max_value The maximum value.
*/
TVM_DLL
ConstIntBound
(
int64_t
min_value
,
int64_t
max_value
);
static
const
constexpr
int64_t
kPosInf
=
ConstIntBoundNode
::
kPosInf
;
static
const
constexpr
int64_t
kNegInf
=
ConstIntBoundNode
::
kNegInf
;
TVM_DEFINE_NODE_REF_METHODS
(
ConstIntBound
,
NodeRef
,
ConstIntBoundNode
);
};
/*!
* \brief Analyzer to get constant integer bound over expression.
...
...
@@ -134,11 +144,6 @@ class ConstIntBoundAnalyzer {
};
/*!
* \brief reference of ModularSetNode
* \sa ModularSetNode
*/
class
ModularSet
;
/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
...
...
@@ -162,13 +167,20 @@ class ModularSetNode : public Node {
v
->
Visit
(
"base"
,
&
base
);
}
TVM_DLL
static
ModularSet
make
(
int64_t
coeff
,
int64_t
base
);
static
constexpr
const
char
*
_type_key
=
"arith.ModularSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
ModularSetNode
,
Node
);
};
TVM_DEFINE_NODE_REF
(
ModularSet
,
ModularSetNode
);
/*!
* \brief reference of ModularSetNode
* \sa ModularSetNode
*/
class
ModularSet
:
public
NodeRef
{
public
:
TVM_DLL
ModularSet
(
int64_t
coeff
,
int64_t
base
);
TVM_DEFINE_NODE_REF_METHODS
(
ModularSet
,
NodeRef
,
ModularSetNode
);
};
/*!
* \brief Analyzer to get modular information over expression.
...
...
include/tvm/base.h
View file @
e550bdd0
...
...
@@ -39,21 +39,24 @@ using ::tvm::Node;
using
::
tvm
::
NodeRef
;
using
::
tvm
::
AttrVisitor
;
/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public ::tvm::NodeRef { \
public: \
/*!
* \brief Macro to define common node ref methods.
* \param TypeName The name of the NodeRef.
* \param BaseTypeName The Base type.
* \param NodeName The node container type.
*/
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {}
\
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {}
\
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
using ContainerType = NodeName;
\
}; \
operator bool() const { return this->defined(); }
\
using ContainerType = NodeName;
/*!
* \brief Macro to
make it easy to define node ref type that
*
has a CopyOnWrite member function
.
* \brief Macro to
define CopyOnWrite function in a NodeRef.
*
\param NodeName The Type of the Node
.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
...
...
@@ -70,25 +73,33 @@ using ::tvm::AttrVisitor;
*
* \endcode
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
inline NodeName* CopyOnWrite() { \
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
NodeName* CopyOnWrite() { \
CHECK(node_ != nullptr); \
if (!node_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
NodePtr<Node>(std::move(n)).swap(node_); \
} \
return static_cast<NodeName*>(node_.get()); \
} \
using ContainerType = NodeName; \
};
}
/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public ::tvm::NodeRef { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
}; \
/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
TVM_DEFINE_NODE_REF_COW(NodeName); \
};
/*!
* \brief save the node as well as all the node it depends on as json.
...
...
src/api/api_arith.cc
View file @
e550bdd0
...
...
@@ -58,7 +58,6 @@ TVM_REGISTER_API("arith.DeduceBound")
TVM_REGISTER_API
(
"arith.DomainTouched"
)
.
set_body_typed
(
DomainTouched
);
TVM_REGISTER_API
(
"_IntervalSetGetMin"
)
.
set_body_method
(
&
IntSet
::
min
);
...
...
@@ -71,11 +70,19 @@ TVM_REGISTER_API("_IntSetIsNothing")
TVM_REGISTER_API
(
"_IntSetIsEverything"
)
.
set_body_method
(
&
IntSet
::
is_everything
);
ConstIntBound
MakeConstIntBound
(
int64_t
min_value
,
int64_t
max_value
)
{
return
ConstIntBound
(
min_value
,
max_value
);
}
TVM_REGISTER_API
(
"arith._make_ConstIntBound"
)
.
set_body_typed
(
ConstIntBoundNode
::
make
);
.
set_body_typed
(
MakeConstIntBound
);
ModularSet
MakeModularSet
(
int64_t
coeff
,
int64_t
base
)
{
return
ModularSet
(
coeff
,
base
);
}
TVM_REGISTER_API
(
"arith._make_ModularSet"
)
.
set_body_typed
(
M
odularSetNode
::
make
);
.
set_body_typed
(
M
akeModularSet
);
TVM_REGISTER_API
(
"arith._CreateAnalyzer"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
...
...
src/arithmetic/const_int_bound.cc
View file @
e550bdd0
...
...
@@ -34,12 +34,12 @@ using namespace ir;
TVM_REGISTER_NODE_TYPE
(
ConstIntBoundNode
);
ConstIntBound
ConstIntBoundNode
::
make
(
ConstIntBound
::
ConstIntBound
(
int64_t
min_value
,
int64_t
max_value
)
{
auto
node
=
make_node
<
ConstIntBoundNode
>
();
node
->
min_value
=
min_value
;
node
->
max_value
=
max_value
;
return
ConstIntBound
(
node
);
node_
=
std
::
move
(
node
);
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
...
...
@@ -289,8 +289,8 @@ class ConstIntBoundAnalyzer::Impl :
std
::
vector
<
BoundInfo
>
additional_info_
;
// constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static
const
constexpr
int64_t
kNegInf
=
ConstIntBound
Node
::
kNegInf
;
static
const
constexpr
int64_t
kPosInf
=
ConstIntBound
Node
::
kPosInf
;
static
const
constexpr
int64_t
kNegInf
=
ConstIntBound
::
kNegInf
;
static
const
constexpr
int64_t
kPosInf
=
ConstIntBound
::
kPosInf
;
static_assert
(
-
kNegInf
==
kPosInf
,
"invariant of inf"
);
// internal helper functions
/*!
...
...
@@ -462,7 +462,7 @@ class ConstIntBoundAnalyzer::Impl :
ConstIntBound
ConstIntBoundAnalyzer
::
operator
()(
const
Expr
&
expr
)
{
Entry
ret
=
impl_
->
VisitExpr
(
expr
);
return
ConstIntBound
Node
::
make
(
ret
.
min_value
,
ret
.
max_value
);
return
ConstIntBound
(
ret
.
min_value
,
ret
.
max_value
);
}
void
ConstIntBoundAnalyzer
::
Update
(
const
Var
&
var
,
...
...
src/arithmetic/modular_set.cc
View file @
e550bdd0
...
...
@@ -35,11 +35,12 @@ using namespace ir;
TVM_REGISTER_NODE_TYPE
(
ModularSetNode
);
ModularSet
ModularSetNode
::
make
(
int64_t
coeff
,
int64_t
base
)
{
ModularSet
::
ModularSet
(
int64_t
coeff
,
int64_t
base
)
{
auto
node
=
make_node
<
ModularSetNode
>
();
node
->
coeff
=
coeff
;
node
->
base
=
base
;
return
ModularSet
(
node
);
// finish construction.
node_
=
std
::
move
(
node
);
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
...
...
@@ -372,7 +373,7 @@ class ModularSetAnalyzer::Impl :
ModularSet
ModularSetAnalyzer
::
operator
()(
const
Expr
&
expr
)
{
Entry
ret
=
impl_
->
VisitExpr
(
expr
);
return
ModularSet
Node
::
make
(
ret
.
coeff
,
ret
.
base
);
return
ModularSet
(
ret
.
coeff
,
ret
.
base
);
}
void
ModularSetAnalyzer
::
Update
(
const
Var
&
var
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment