Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e550bdd0
Unverified
Commit
e550bdd0
authored
May 23, 2019
by
Tianqi Chen
Committed by
GitHub
May 23, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NODE] Macro to define NodeRef methods, constructor style example (#3224)
parent
e1e91f1f
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
86 additions
and
55 deletions
+86
-55
include/tvm/arithmetic.h
+28
-16
include/tvm/base.h
+36
-25
src/api/api_arith.cc
+12
-5
src/arithmetic/const_int_bound.cc
+5
-5
src/arithmetic/modular_set.cc
+5
-4
No files found.
include/tvm/arithmetic.h
View file @
e550bdd0
...
@@ -48,11 +48,7 @@ namespace arith {
...
@@ -48,11 +48,7 @@ namespace arith {
// Forward declare Analyzer
// Forward declare Analyzer
class
Analyzer
;
class
Analyzer
;
/*!
* \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode
*/
class
ConstIntBound
;
/*!
/*!
* \brief Constant integer up and lower bound(inclusive).
* \brief Constant integer up and lower bound(inclusive).
* Useful for value bound analysis.
* Useful for value bound analysis.
...
@@ -69,8 +65,6 @@ class ConstIntBoundNode : public Node {
...
@@ -69,8 +65,6 @@ class ConstIntBoundNode : public Node {
v
->
Visit
(
"max_value"
,
&
max_value
);
v
->
Visit
(
"max_value"
,
&
max_value
);
}
}
TVM_DLL
static
ConstIntBound
make
(
int64_t
min_value
,
int64_t
max_value
);
/*! \brief Number to represent +inf */
/*! \brief Number to represent +inf */
static
const
constexpr
int64_t
kPosInf
=
std
::
numeric_limits
<
int64_t
>::
max
();
static
const
constexpr
int64_t
kPosInf
=
std
::
numeric_limits
<
int64_t
>::
max
();
/*!
/*!
...
@@ -83,7 +77,23 @@ class ConstIntBoundNode : public Node {
...
@@ -83,7 +77,23 @@ class ConstIntBoundNode : public Node {
TVM_DECLARE_NODE_TYPE_INFO
(
ConstIntBoundNode
,
Node
);
TVM_DECLARE_NODE_TYPE_INFO
(
ConstIntBoundNode
,
Node
);
};
};
TVM_DEFINE_NODE_REF
(
ConstIntBound
,
ConstIntBoundNode
);
/*!
* \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode
*/
class
ConstIntBound
:
public
NodeRef
{
public
:
/*!
* \brief constructor by fields.
* \param min_value The mininum value.
* \param max_value The maximum value.
*/
TVM_DLL
ConstIntBound
(
int64_t
min_value
,
int64_t
max_value
);
static
const
constexpr
int64_t
kPosInf
=
ConstIntBoundNode
::
kPosInf
;
static
const
constexpr
int64_t
kNegInf
=
ConstIntBoundNode
::
kNegInf
;
TVM_DEFINE_NODE_REF_METHODS
(
ConstIntBound
,
NodeRef
,
ConstIntBoundNode
);
};
/*!
/*!
* \brief Analyzer to get constant integer bound over expression.
* \brief Analyzer to get constant integer bound over expression.
...
@@ -134,11 +144,6 @@ class ConstIntBoundAnalyzer {
...
@@ -134,11 +144,6 @@ class ConstIntBoundAnalyzer {
};
};
/*!
/*!
* \brief reference of ModularSetNode
* \sa ModularSetNode
*/
class
ModularSet
;
/*!
* \brief Range of a linear integer function.
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
* Use to do specify the possible index values.
*
*
...
@@ -162,13 +167,20 @@ class ModularSetNode : public Node {
...
@@ -162,13 +167,20 @@ class ModularSetNode : public Node {
v
->
Visit
(
"base"
,
&
base
);
v
->
Visit
(
"base"
,
&
base
);
}
}
TVM_DLL
static
ModularSet
make
(
int64_t
coeff
,
int64_t
base
);
static
constexpr
const
char
*
_type_key
=
"arith.ModularSet"
;
static
constexpr
const
char
*
_type_key
=
"arith.ModularSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
ModularSetNode
,
Node
);
TVM_DECLARE_NODE_TYPE_INFO
(
ModularSetNode
,
Node
);
};
};
TVM_DEFINE_NODE_REF
(
ModularSet
,
ModularSetNode
);
/*!
* \brief reference of ModularSetNode
* \sa ModularSetNode
*/
class
ModularSet
:
public
NodeRef
{
public
:
TVM_DLL
ModularSet
(
int64_t
coeff
,
int64_t
base
);
TVM_DEFINE_NODE_REF_METHODS
(
ModularSet
,
NodeRef
,
ModularSetNode
);
};
/*!
/*!
* \brief Analyzer to get modular information over expression.
* \brief Analyzer to get modular information over expression.
...
...
include/tvm/base.h
View file @
e550bdd0
...
@@ -39,21 +39,24 @@ using ::tvm::Node;
...
@@ -39,21 +39,24 @@ using ::tvm::Node;
using
::
tvm
::
NodeRef
;
using
::
tvm
::
NodeRef
;
using
::
tvm
::
AttrVisitor
;
using
::
tvm
::
AttrVisitor
;
/*! \brief Macro to make it easy to define node ref type given node */
/*!
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
* \brief Macro to define common node ref methods.
class TypeName : public ::tvm::NodeRef { \
* \param TypeName The name of the NodeRef.
public: \
* \param BaseTypeName The Base type.
TypeName() {} \
* \param NodeName The node container type.
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} \
*/
const NodeName* operator->() const { \
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
return static_cast<const NodeName*>(node_.get()); \
TypeName() {} \
} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \
using ContainerType = NodeName; \
const NodeName* operator->() const { \
}; \
return static_cast<const NodeName*>(node_.get()); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;
/*!
/*!
* \brief Macro to
make it easy to define node ref type that
* \brief Macro to
define CopyOnWrite function in a NodeRef.
*
has a CopyOnWrite member function
.
*
\param NodeName The Type of the Node
.
*
*
* CopyOnWrite will generate a unique copy of the internal node.
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The node will be copied if it is referenced by multiple places.
...
@@ -70,25 +73,33 @@ using ::tvm::AttrVisitor;
...
@@ -70,25 +73,33 @@ using ::tvm::AttrVisitor;
*
*
* \endcode
* \endcode
*/
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
class TypeName : public BaseType { \
NodeName* CopyOnWrite() { \
public: \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
inline NodeName* CopyOnWrite() { \
CHECK(node_ != nullptr); \
CHECK(node_ != nullptr); \
if (!node_.unique()) { \
if (!node_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
NodePtr<Node>(std::move(n)).swap(node_); \
NodePtr<Node>(std::move(n)).swap(node_); \
} \
} \
return static_cast<NodeName*>(node_.get()); \
return static_cast<NodeName*>(node_.get()); \
} \
}
using ContainerType = NodeName; \
};
/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public ::tvm::NodeRef { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
}; \
/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
TVM_DEFINE_NODE_REF_COW(NodeName); \
};
/*!
/*!
* \brief save the node as well as all the node it depends on as json.
* \brief save the node as well as all the node it depends on as json.
...
...
src/api/api_arith.cc
View file @
e550bdd0
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -58,7 +58,6 @@ TVM_REGISTER_API("arith.DeduceBound")
...
@@ -58,7 +58,6 @@ TVM_REGISTER_API("arith.DeduceBound")
TVM_REGISTER_API
(
"arith.DomainTouched"
)
TVM_REGISTER_API
(
"arith.DomainTouched"
)
.
set_body_typed
(
DomainTouched
);
.
set_body_typed
(
DomainTouched
);
TVM_REGISTER_API
(
"_IntervalSetGetMin"
)
TVM_REGISTER_API
(
"_IntervalSetGetMin"
)
.
set_body_method
(
&
IntSet
::
min
);
.
set_body_method
(
&
IntSet
::
min
);
...
@@ -71,11 +70,19 @@ TVM_REGISTER_API("_IntSetIsNothing")
...
@@ -71,11 +70,19 @@ TVM_REGISTER_API("_IntSetIsNothing")
TVM_REGISTER_API
(
"_IntSetIsEverything"
)
TVM_REGISTER_API
(
"_IntSetIsEverything"
)
.
set_body_method
(
&
IntSet
::
is_everything
);
.
set_body_method
(
&
IntSet
::
is_everything
);
ConstIntBound
MakeConstIntBound
(
int64_t
min_value
,
int64_t
max_value
)
{
return
ConstIntBound
(
min_value
,
max_value
);
}
TVM_REGISTER_API
(
"arith._make_ConstIntBound"
)
TVM_REGISTER_API
(
"arith._make_ConstIntBound"
)
.
set_body_typed
(
ConstIntBoundNode
::
make
);
.
set_body_typed
(
MakeConstIntBound
);
ModularSet
MakeModularSet
(
int64_t
coeff
,
int64_t
base
)
{
return
ModularSet
(
coeff
,
base
);
}
TVM_REGISTER_API
(
"arith._make_ModularSet"
)
TVM_REGISTER_API
(
"arith._make_ModularSet"
)
.
set_body_typed
(
M
odularSetNode
::
make
);
.
set_body_typed
(
M
akeModularSet
);
TVM_REGISTER_API
(
"arith._CreateAnalyzer"
)
TVM_REGISTER_API
(
"arith._CreateAnalyzer"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
...
...
src/arithmetic/const_int_bound.cc
View file @
e550bdd0
...
@@ -34,12 +34,12 @@ using namespace ir;
...
@@ -34,12 +34,12 @@ using namespace ir;
TVM_REGISTER_NODE_TYPE
(
ConstIntBoundNode
);
TVM_REGISTER_NODE_TYPE
(
ConstIntBoundNode
);
ConstIntBound
ConstIntBoundNode
::
make
(
ConstIntBound
::
ConstIntBound
(
int64_t
min_value
,
int64_t
max_value
)
{
int64_t
min_value
,
int64_t
max_value
)
{
auto
node
=
make_node
<
ConstIntBoundNode
>
();
auto
node
=
make_node
<
ConstIntBoundNode
>
();
node
->
min_value
=
min_value
;
node
->
min_value
=
min_value
;
node
->
max_value
=
max_value
;
node
->
max_value
=
max_value
;
return
ConstIntBound
(
node
);
node_
=
std
::
move
(
node
);
}
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
...
@@ -289,8 +289,8 @@ class ConstIntBoundAnalyzer::Impl :
...
@@ -289,8 +289,8 @@ class ConstIntBoundAnalyzer::Impl :
std
::
vector
<
BoundInfo
>
additional_info_
;
std
::
vector
<
BoundInfo
>
additional_info_
;
// constants: the limit value means umlimited
// constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
// NOTE: kNegInf/kPosInf are used to represent infinity.
static
const
constexpr
int64_t
kNegInf
=
ConstIntBound
Node
::
kNegInf
;
static
const
constexpr
int64_t
kNegInf
=
ConstIntBound
::
kNegInf
;
static
const
constexpr
int64_t
kPosInf
=
ConstIntBound
Node
::
kPosInf
;
static
const
constexpr
int64_t
kPosInf
=
ConstIntBound
::
kPosInf
;
static_assert
(
-
kNegInf
==
kPosInf
,
"invariant of inf"
);
static_assert
(
-
kNegInf
==
kPosInf
,
"invariant of inf"
);
// internal helper functions
// internal helper functions
/*!
/*!
...
@@ -462,7 +462,7 @@ class ConstIntBoundAnalyzer::Impl :
...
@@ -462,7 +462,7 @@ class ConstIntBoundAnalyzer::Impl :
ConstIntBound
ConstIntBoundAnalyzer
::
operator
()(
const
Expr
&
expr
)
{
ConstIntBound
ConstIntBoundAnalyzer
::
operator
()(
const
Expr
&
expr
)
{
Entry
ret
=
impl_
->
VisitExpr
(
expr
);
Entry
ret
=
impl_
->
VisitExpr
(
expr
);
return
ConstIntBound
Node
::
make
(
ret
.
min_value
,
ret
.
max_value
);
return
ConstIntBound
(
ret
.
min_value
,
ret
.
max_value
);
}
}
void
ConstIntBoundAnalyzer
::
Update
(
const
Var
&
var
,
void
ConstIntBoundAnalyzer
::
Update
(
const
Var
&
var
,
...
...
src/arithmetic/modular_set.cc
View file @
e550bdd0
...
@@ -35,11 +35,12 @@ using namespace ir;
...
@@ -35,11 +35,12 @@ using namespace ir;
TVM_REGISTER_NODE_TYPE
(
ModularSetNode
);
TVM_REGISTER_NODE_TYPE
(
ModularSetNode
);
ModularSet
ModularSetNode
::
make
(
int64_t
coeff
,
int64_t
base
)
{
ModularSet
::
ModularSet
(
int64_t
coeff
,
int64_t
base
)
{
auto
node
=
make_node
<
ModularSetNode
>
();
auto
node
=
make_node
<
ModularSetNode
>
();
node
->
coeff
=
coeff
;
node
->
coeff
=
coeff
;
node
->
base
=
base
;
node
->
base
=
base
;
return
ModularSet
(
node
);
// finish construction.
node_
=
std
::
move
(
node
);
}
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
...
@@ -366,13 +367,13 @@ class ModularSetAnalyzer::Impl :
...
@@ -366,13 +367,13 @@ class ModularSetAnalyzer::Impl :
* \return Bound that represent everything dtype can represent.
* \return Bound that represent everything dtype can represent.
*/
*/
static
Entry
Nothing
()
{
static
Entry
Nothing
()
{
return
Entry
(
0
,
1
);
return
Entry
(
0
,
1
);
}
}
};
};
ModularSet
ModularSetAnalyzer
::
operator
()(
const
Expr
&
expr
)
{
ModularSet
ModularSetAnalyzer
::
operator
()(
const
Expr
&
expr
)
{
Entry
ret
=
impl_
->
VisitExpr
(
expr
);
Entry
ret
=
impl_
->
VisitExpr
(
expr
);
return
ModularSet
Node
::
make
(
ret
.
coeff
,
ret
.
base
);
return
ModularSet
(
ret
.
coeff
,
ret
.
base
);
}
}
void
ModularSetAnalyzer
::
Update
(
const
Var
&
var
,
void
ModularSetAnalyzer
::
Update
(
const
Var
&
var
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment