Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
cf0fc361
Commit
cf0fc361
authored
Oct 22, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[REFACTOR] Move Node always bebind NodeRef, expose ->
parent
13383928
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
154 additions
and
128 deletions
+154
-128
include/tvm/domain.h
+96
-68
include/tvm/expr_node.h
+2
-2
include/tvm/tensor.h
+55
-57
src/expr/expr_util.cc
+1
-1
No files found.
include/tvm/domain.h
View file @
cf0fc361
...
...
@@ -13,27 +13,10 @@
namespace
tvm
{
/*! \brief range over one dimension */
class
RangeNode
:
public
Node
{
public
:
/*! \brief beginning of the node */
Expr
begin
;
/*! \brief end of the node */
Expr
end
;
/*! \brief constructor */
RangeNode
()
{}
RangeNode
(
Expr
&&
begin
,
Expr
&&
end
)
:
begin
(
std
::
move
(
begin
)),
end
(
std
::
move
(
end
))
{
}
const
char
*
type_key
()
const
override
{
return
"RangeNode"
;
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"begin"
,
&
begin
);
fvisit
(
"end"
,
&
end
);
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{}
};
// Internal node container of Range
class
RangeNode
;
// Internal node container of RDomain
class
RDomainNode
;
/*! \brief Node range */
class
Range
:
public
NodeRef
{
...
...
@@ -48,14 +31,16 @@ class Range : public NodeRef {
Range
(
Expr
begin
,
Expr
end
);
/*! \return The extent of the range */
Expr
extent
()
const
;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline
const
RangeNode
*
operator
->
()
const
;
/*! \return the begining of the range */
inline
const
Expr
&
begin
()
const
{
return
static_cast
<
const
RangeNode
*>
(
node_
.
get
())
->
begin
;
}
inline
const
Expr
&
begin
()
const
;
/*! \return the end of the range */
inline
const
Expr
&
end
()
const
{
return
static_cast
<
const
RangeNode
*>
(
node_
.
get
())
->
end
;
}
inline
const
Expr
&
end
()
const
;
// overload print function
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Range
&
r
)
{
// NOLINT(*)
os
<<
'['
<<
r
.
begin
()
<<
", "
<<
r
.
end
()
<<
')'
;
return
os
;
...
...
@@ -65,28 +50,6 @@ class Range : public NodeRef {
/*! \brief Domain is a multi-dimensional range */
using
Domain
=
Array
<
Range
>
;
/*! \brief reduction domain node */
class
RDomainNode
:
public
Node
{
public
:
/*! \brief internal index */
Array
<
Var
>
index
;
/*! \brief The inernal domain */
Domain
domain
;
/*! \brief constructor */
RDomainNode
()
{}
RDomainNode
(
Array
<
Var
>
&&
index
,
Domain
&&
domain
)
:
index
(
std
::
move
(
index
)),
domain
(
std
::
move
(
domain
))
{
}
const
char
*
type_key
()
const
override
{
return
"RDomainNode"
;
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"index"
,
&
index
);
fvisit
(
"domain"
,
&
domain
);
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{}
};
/*! \brief reduction domain */
class
RDomain
:
public
NodeRef
{
public
:
...
...
@@ -104,35 +67,27 @@ class RDomain : public NodeRef {
explicit
RDomain
(
std
::
initializer_list
<
Range
>
domain
)
:
RDomain
(
Domain
(
domain
))
{}
/*!
* \brief
constructor from node point
er
* \
param nptr Another node shared point
er
* \brief
access the internal node contain
er
* \
return the pointer to the internal node contain
er
*/
explicit
RDomain
(
std
::
shared_ptr
<
Node
>&&
nptr
)
:
NodeRef
(
std
::
move
(
nptr
))
{
CHECK
(
node_
.
get
()
!=
nullptr
);
CHECK
(
node_
->
is_type
<
RDomainNode
>
());
}
inline
const
RDomainNode
*
operator
->
()
const
;
/*! \return The dimension of the RDomain */
inline
size_t
ndim
()
const
{
return
static_cast
<
const
RDomainNode
*>
(
node_
.
get
())
->
index
.
size
();
}
/*! \return the 0-th index of the domain */
inline
Var
i0
()
const
{
return
index
(
0
);
}
inline
size_t
ndim
()
const
;
/*!
* \param i the index.
* \return i-th index variable in the RDomain
*/
inline
Var
index
(
size_t
i
)
const
{
return
static_cast
<
const
RDomainNode
*>
(
node_
.
get
())
->
index
[
i
];
inline
Var
index
(
size_t
i
)
const
;
/*! \return the 0-th index of the domain */
inline
Var
i0
()
const
{
return
index
(
0
);
}
/*!
* \return The domain of the reduction.
*/
inline
const
Domain
&
domain
()
const
{
return
static_cast
<
const
RDomainNode
*>
(
node_
.
get
())
->
domain
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
RDomain
&
r
)
{
// NOLINT(*)
inline
const
Domain
&
domain
()
const
;
// overload print function
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
RDomain
&
r
){
// NOLINT(*)
os
<<
"rdomain("
<<
r
.
domain
()
<<
")"
;
return
os
;
}
...
...
@@ -141,6 +96,79 @@ class RDomain : public NodeRef {
/*! \brief use RDom as alias of RDomain */
using
RDom
=
RDomain
;
/*! \brief range over one dimension */
class
RangeNode
:
public
Node
{
public
:
/*! \brief beginning of the node */
Expr
begin
;
/*! \brief end of the node */
Expr
end
;
/*! \brief constructor */
RangeNode
()
{}
RangeNode
(
Expr
&&
begin
,
Expr
&&
end
)
:
begin
(
std
::
move
(
begin
)),
end
(
std
::
move
(
end
))
{
}
const
char
*
type_key
()
const
override
{
return
"RangeNode"
;
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"begin"
,
&
begin
);
fvisit
(
"end"
,
&
end
);
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{}
};
/*! \brief reduction domain node */
class
RDomainNode
:
public
Node
{
public
:
/*! \brief internal index */
Array
<
Var
>
index
;
/*! \brief The inernal domain */
Domain
domain
;
/*! \brief constructor */
RDomainNode
()
{}
RDomainNode
(
Array
<
Var
>
&&
index
,
Domain
&&
domain
)
:
index
(
std
::
move
(
index
)),
domain
(
std
::
move
(
domain
))
{
}
const
char
*
type_key
()
const
override
{
return
"RDomainNode"
;
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"index"
,
&
index
);
fvisit
(
"domain"
,
&
domain
);
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{}
};
// implements of inline functions
inline
const
RangeNode
*
Range
::
operator
->
()
const
{
return
static_cast
<
const
RangeNode
*>
(
node_
.
get
());
}
inline
const
Expr
&
Range
::
begin
()
const
{
return
(
*
this
)
->
begin
;
}
inline
const
Expr
&
Range
::
end
()
const
{
return
(
*
this
)
->
end
;
}
inline
const
RDomainNode
*
RDomain
::
operator
->
()
const
{
return
static_cast
<
const
RDomainNode
*>
(
node_
.
get
());
}
inline
size_t
RDomain
::
ndim
()
const
{
return
(
*
this
)
->
index
.
size
();
}
inline
Var
RDomain
::
index
(
size_t
i
)
const
{
return
(
*
this
)
->
index
[
i
];
}
inline
const
Domain
&
RDomain
::
domain
()
const
{
return
(
*
this
)
->
domain
;
}
}
// namespace tvm
#endif // TVM_DOMAIN_H_
include/tvm/expr_node.h
View file @
cf0fc361
...
...
@@ -192,7 +192,7 @@ struct TensorReadNode : public ExprNode {
TensorReadNode
(
Tensor
&&
tensor
,
Array
<
Expr
>
&&
indices
)
:
tensor
(
std
::
move
(
tensor
)),
indices
(
std
::
move
(
indices
))
{
node_type_
=
kReduceNode
;
dtype_
=
tensor
.
dtype
()
;
dtype_
=
tensor
->
dtype
;
}
~
TensorReadNode
()
{
this
->
Destroy
();
...
...
@@ -201,7 +201,7 @@ struct TensorReadNode : public ExprNode {
return
"TensorReadNode"
;
}
void
Verify
()
const
override
{
CHECK_EQ
(
dtype_
,
tensor
.
dtype
()
);
CHECK_EQ
(
dtype_
,
tensor
->
dtype
);
for
(
size_t
i
=
0
;
i
<
indices
.
size
();
++
i
)
{
CHECK_EQ
(
indices
[
i
].
dtype
(),
kInt32
);
}
...
...
include/tvm/tensor.h
View file @
cf0fc361
...
...
@@ -15,34 +15,8 @@
namespace
tvm
{
/*! \brief Node to represent a tensor */
class
TensorNode
:
public
Node
{
public
:
/*! \brief optional name of the tensor */
std
::
string
name
;
/*! \brief data type in the content of the tensor */
DataType
dtype
;
/*! \brief The index representing each dimension, used by source expression. */
Array
<
Var
>
dim_index
;
/*! \brief The shape of the tensor */
Array
<
Expr
>
shape
;
/*! \brief source expression */
Expr
source
;
/*! \brief constructor */
TensorNode
()
{}
const
char
*
type_key
()
const
override
{
return
"TensorNode"
;
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{
visitor
->
Visit
(
"name"
,
&
name
);
visitor
->
Visit
(
"dtype"
,
&
dtype
);
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"dim_index"
,
&
dim_index
);
fvisit
(
"shape"
,
&
shape
);
fvisit
(
"source"
,
&
source
);
}
};
// Internal node container of Tensor
class
TensorNode
;
/*! \brief The compute function to specify the input source of a Tensor */
using
FCompute
=
std
::
function
<
Expr
(
const
Array
<
Var
>&
i
)
>
;
...
...
@@ -94,30 +68,13 @@ class Tensor : public NodeRef {
:
Tensor
(
shape
,
GetFCompute
(
f
),
name
)
{}
Tensor
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
,
Var
,
Var
,
Var
)
>
f
,
std
::
string
name
=
"tensor"
)
:
Tensor
(
shape
,
GetFCompute
(
f
),
name
)
{}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline
const
TensorNode
*
operator
->
()
const
;
/*! \return The dimension of the tensor */
inline
size_t
ndim
()
const
{
return
static_cast
<
const
TensorNode
*>
(
node_
.
get
())
->
shape
.
size
();
}
/*! \return The name of the tensor */
inline
const
std
::
string
&
name
()
const
{
return
static_cast
<
const
TensorNode
*>
(
node_
.
get
())
->
name
;
}
/*! \return The data type tensor */
inline
DataType
dtype
()
const
{
return
static_cast
<
const
TensorNode
*>
(
node_
.
get
())
->
dtype
;
}
/*! \return The source expression of intermediate tensor */
inline
const
Expr
&
source
()
const
{
return
static_cast
<
const
TensorNode
*>
(
node_
.
get
())
->
source
;
}
/*! \return The internal dimension index used by source expression */
inline
const
Array
<
Var
>&
dim_index
()
const
{
return
static_cast
<
const
TensorNode
*>
(
node_
.
get
())
->
dim_index
;
}
/*! \return The shape of the tensor */
inline
const
Array
<
Expr
>&
shape
()
const
{
return
static_cast
<
const
TensorNode
*>
(
node_
.
get
())
->
shape
;
}
inline
size_t
ndim
()
const
;
/*!
* \brief Take elements from the tensor
* \param args The indices
...
...
@@ -138,14 +95,55 @@ class Tensor : public NodeRef {
std
::
vector
<
Tensor
>
InputTensors
()
const
;
/*! \return whether the tensor stores a result of reduction */
bool
IsRTensor
()
const
;
// printt function
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Tensor
&
t
)
{
// NOLINT(*)
os
<<
"Tensor(shape="
<<
t
.
shape
()
<<
", source="
<<
t
.
source
()
<<
", name="
<<
t
.
name
()
<<
')'
;
return
os
;
// overload print function
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Tensor
&
t
);
};
/*! \brief Node to represent a tensor */
class
TensorNode
:
public
Node
{
public
:
/*! \brief optional name of the tensor */
std
::
string
name
;
/*! \brief data type in the content of the tensor */
DataType
dtype
;
/*! \brief The index representing each dimension, used by source expression. */
Array
<
Var
>
dim_index
;
/*! \brief The shape of the tensor */
Array
<
Expr
>
shape
;
/*! \brief source expression */
Expr
source
;
/*! \brief constructor */
TensorNode
()
{}
const
char
*
type_key
()
const
override
{
return
"TensorNode"
;
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{
visitor
->
Visit
(
"name"
,
&
name
);
visitor
->
Visit
(
"dtype"
,
&
dtype
);
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"dim_index"
,
&
dim_index
);
fvisit
(
"shape"
,
&
shape
);
fvisit
(
"source"
,
&
source
);
}
};
// implementations
inline
const
TensorNode
*
Tensor
::
operator
->
()
const
{
return
static_cast
<
const
TensorNode
*>
(
node_
.
get
());
}
inline
size_t
Tensor
::
ndim
()
const
{
return
(
*
this
)
->
shape
.
size
();
}
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Tensor
&
t
)
{
// NOLINT(*)
os
<<
"Tensor(shape="
<<
t
->
shape
<<
", source="
<<
t
->
source
<<
", name="
<<
t
->
name
<<
')'
;
return
os
;
}
}
// namespace tvm
#endif // TVM_TENSOR_H_
src/expr/expr_util.cc
View file @
cf0fc361
...
...
@@ -256,7 +256,7 @@ void Expr::Print(std::ostream& os) const {
}
case
kTensorReadNode
:
{
const
auto
*
n
=
Get
<
TensorReadNode
>
();
os
<<
n
->
tensor
.
name
()
<<
n
->
indices
;
os
<<
n
->
tensor
->
name
<<
n
->
indices
;
return
;
}
default
:
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment