Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
56e10eb0
Commit
56e10eb0
authored
Oct 19, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Tensor API
parent
5f829774
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
199 additions
and
17 deletions
+199
-17
include/tvm/array.h
+2
-2
include/tvm/expr_node.h
+38
-1
include/tvm/tensor.h
+94
-10
src/expr/domain.cc
+3
-1
src/expr/expr.cc
+5
-0
src/expr/expr_node.cc
+1
-0
src/expr/tensor.cc
+48
-0
tests/cpp/tensor_test.cc
+8
-3
No files found.
include/tvm/array.h
View file @
56e10eb0
...
...
@@ -23,10 +23,10 @@ class ArrayNode : public Node {
return
"ArrayNode"
;
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{
LOG
(
FATAL
)
<<
"need to specially handle list"
;
LOG
(
FATAL
)
<<
"need to specially handle list
attrs
"
;
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
LOG
(
FATAL
)
<<
"need to specially handle list"
;
// Do nothing, specially handled
}
};
...
...
include/tvm/expr_node.h
View file @
56e10eb0
...
...
@@ -141,7 +141,7 @@ struct BinaryOpNode : public ExprNode {
}
};
/*! \brief
Binary mapping
operator */
/*! \brief
Reduction operator
operator */
struct
ReduceNode
:
public
ExprNode
{
public
:
/*! \brief The operator */
...
...
@@ -178,6 +178,43 @@ struct ReduceNode : public ExprNode {
}
};
/*! \brief Tensor read operator */
struct
TensorReadNode
:
public
ExprNode
{
public
:
/*! \brief The tensor to be read from */
Tensor
tensor
;
/*! \brief The indices of read */
Array
<
Expr
>
indices
;
/*! \brief constructor, do not use constructor */
TensorReadNode
()
{
node_type_
=
kTensorReadNode
;
}
TensorReadNode
(
Tensor
&&
tensor
,
Array
<
Expr
>
&&
indices
)
:
tensor
(
std
::
move
(
tensor
)),
indices
(
std
::
move
(
indices
))
{
node_type_
=
kReduceNode
;
dtype_
=
tensor
.
dtype
();
}
~
TensorReadNode
()
{
this
->
Destroy
();
}
const
char
*
type_key
()
const
override
{
return
"TensorReadNode"
;
}
void
Verify
()
const
override
{
CHECK_EQ
(
dtype_
,
tensor
.
dtype
());
for
(
size_t
i
=
0
;
i
<
indices
.
size
();
++
i
)
{
CHECK_EQ
(
indices
[
i
].
dtype
(),
kInt32
);
}
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{
visitor
->
Visit
(
"dtype"
,
&
dtype_
);
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"tensor"
,
&
tensor
);
fvisit
(
"indices"
,
&
indices
);
}
};
}
// namespace tvm
#endif // TVM_EXPR_NODE_H_
include/tvm/tensor.h
View file @
56e10eb0
...
...
@@ -7,6 +7,7 @@
#define TVM_TENSOR_H_
#include <string>
#include <type_traits>
#include "./expr.h"
#include "./array.h"
...
...
@@ -19,15 +20,14 @@ class TensorNode : public Node {
std
::
string
name
;
/*! \brief data type in the content of the tensor */
DataType
dtype
;
/*! \brief The index
on each dimension
*/
/*! \brief The index
representing each dimension, used by source expression.
*/
Array
<
Var
>
dim_index
;
/*! \brief The shape of the tensor */
Array
<
Expr
>
shape
;
/*! \brief source expression */
Expr
source
;
/*! \brief constructor */
TensorNode
()
{
}
TensorNode
()
{}
const
char
*
type_key
()
const
override
{
return
"TensorNode"
;
}
...
...
@@ -42,20 +42,104 @@ class TensorNode : public Node {
}
};
/*! \brief The compute function to specify the input source of a Tensor */
using
FCompute
=
std
::
function
<
Expr
(
const
Array
<
Var
>&
i
)
>
;
// converters from other functions into fcompute
inline
FCompute
GetFCompute
(
std
::
function
<
Expr
(
Var
x
)
>
f
)
{
return
[
f
](
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
]);
};
}
inline
FCompute
GetFCompute
(
std
::
function
<
Expr
(
Var
,
Var
)
>
f
)
{
return
[
f
](
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
]);
};
}
inline
FCompute
GetFCompute
(
std
::
function
<
Expr
(
Var
,
Var
,
Var
)
>
f
)
{
return
[
f
](
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
]);
};
}
inline
FCompute
GetFCompute
(
std
::
function
<
Expr
(
Var
,
Var
,
Var
,
Var
)
>
f
)
{
return
[
f
](
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
],
i
[
3
]);
};
}
/*!
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
*/
class
Tensor
:
public
NodeRef
{
public
:
explicit
Tensor
(
Array
<
Expr
>
shape
);
inline
size_t
ndim
()
const
;
/*! \brief default constructor, used internally */
Tensor
()
{}
/*!
* \brief constructor of input tensor
* \param shape Shape of the tensor.
* \param name optional name of the Tensor.
* \param dtype The data type of the input tensor.
*/
explicit
Tensor
(
Array
<
Expr
>
shape
,
std
::
string
name
=
"tensor"
,
DataType
dtype
=
kFloat32
);
/*!
* \brief constructor of intermediate result.
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
*/
Tensor
(
Array
<
Expr
>
shape
,
FCompute
fcompute
,
std
::
string
name
=
"tensor"
);
// same constructor, specialized for different fcompute function
Tensor
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
)
>
f
,
std
::
string
name
=
"tensor"
)
:
Tensor
(
shape
,
GetFCompute
(
f
),
name
)
{}
Tensor
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
,
Var
)
>
f
,
std
::
string
name
=
"tensor"
)
:
Tensor
(
shape
,
GetFCompute
(
f
),
name
)
{}
Tensor
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
,
Var
,
Var
)
>
f
,
std
::
string
name
=
"tensor"
)
:
Tensor
(
shape
,
GetFCompute
(
f
),
name
)
{}
Tensor
(
Array
<
Expr
>
shape
,
std
::
function
<
Expr
(
Var
,
Var
,
Var
,
Var
)
>
f
,
std
::
string
name
=
"tensor"
)
:
Tensor
(
shape
,
GetFCompute
(
f
),
name
)
{}
/*! \return The dimension of the tensor */
inline
size_t
ndim
()
const
{
return
static_cast
<
const
TensorNode
*>
(
node_
.
get
())
->
shape
.
size
();
}
/*! \return The name of the tensor */
inline
const
std
::
string
&
name
()
const
{
return
static_cast
<
const
TensorNode
*>
(
node_
.
get
())
->
name
;
}
/*! \return The data type tensor */
inline
DataType
dtype
()
const
{
return
static_cast
<
const
TensorNode
*>
(
node_
.
get
())
->
dtype
;
}
/*! \return The source expression of intermediate tensor */
inline
const
Expr
&
source
()
const
{
return
static_cast
<
const
TensorNode
*>
(
node_
.
get
())
->
source
;
}
/*! \return The internal dimension index used by source expression */
inline
const
Array
<
Var
>&
dim_index
()
const
{
return
static_cast
<
const
TensorNode
*>
(
node_
.
get
())
->
dim_index
;
}
/*! \return The shape of the tensor */
inline
const
Array
<
Expr
>&
shape
()
const
{
return
static_cast
<
const
TensorNode
*>
(
node_
.
get
())
->
shape
;
}
/*!
* \brief Take elements from the tensor
* \param args The indices
* \return the result expression representing tensor read.
*/
template
<
typename
...
Args
>
inline
Expr
operator
()(
Args
&&
...
args
)
const
{
Array
<
Expr
>
indices
{
std
::
forward
<
Args
>
(
args
)...};
CHECK_EQ
(
ndim
(),
indices
.
size
())
<<
"Tensor dimension mismatch in read"
;
return
Expr
{};
return
operator
()(
indices
);
}
/*!
* \brief Take elements from the tensor
* \param indices the indices.
* \return the result expression representing tensor read.
*/
Expr
operator
()(
Array
<
Expr
>
indices
)
const
;
// printt function
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Tensor
&
t
)
{
// NOLINT(*)
os
<<
"Tensor(shape="
<<
t
.
shape
()
<<
", source="
<<
t
.
source
()
<<
", name="
<<
t
.
name
()
<<
')'
;
return
os
;
}
};
}
// namespace tvm
#endif // TVM_TENSOR_H_
src/expr/domain.cc
View file @
56e10eb0
...
...
@@ -22,7 +22,9 @@ Expr Range::extent() const {
RDomain
::
RDomain
(
Domain
domain
)
{
std
::
vector
<
Var
>
index
;
for
(
size_t
i
=
0
;
i
<
domain
.
size
();
++
i
)
{
index
.
push_back
(
Var
(
"reduction_index"
));
std
::
ostringstream
os
;
os
<<
"reduction_index"
<<
i
;
index
.
push_back
(
Var
(
os
.
str
()));
}
Array
<
Var
>
idx
(
index
);
node_
=
std
::
make_shared
<
RDomainNode
>
(
...
...
src/expr/expr.cc
View file @
56e10eb0
...
...
@@ -55,6 +55,11 @@ void Expr::Print(std::ostream& os) const {
os
<<
", "
<<
n
->
rdom
<<
')'
;
return
;
}
case
kTensorReadNode
:
{
const
auto
*
n
=
Get
<
TensorReadNode
>
();
os
<<
n
->
tensor
.
name
()
<<
n
->
indices
;
return
;
}
default
:
{
LOG
(
FATAL
)
<<
"not able to handle type "
<<
typeid
(
node_
.
get
()).
name
();
}
...
...
src/expr/expr_node.cc
View file @
56e10eb0
...
...
@@ -43,5 +43,6 @@ TVM_REGISTER_NODE_TYPE(FloatNode);
TVM_REGISTER_NODE_TYPE
(
UnaryOpNode
);
TVM_REGISTER_NODE_TYPE
(
BinaryOpNode
);
TVM_REGISTER_NODE_TYPE
(
ReduceNode
);
TVM_REGISTER_NODE_TYPE
(
TensorReadNode
);
}
// namespace tvm
src/expr/tensor.cc
0 → 100644
View file @
56e10eb0
/*!
* Copyright (c) 2016 by Contributors
* \file tensor.cc
*/
#include <tvm/tensor.h>
#include <tvm/expr_node.h>
#include <memory>
namespace
tvm
{
Tensor
::
Tensor
(
Array
<
Expr
>
shape
,
std
::
string
name
,
DataType
dtype
)
{
auto
node
=
std
::
make_shared
<
TensorNode
>
();
node
->
name
=
std
::
move
(
name
);
node
->
dtype
=
dtype
;
node
->
shape
=
std
::
move
(
shape
);
node_
=
std
::
move
(
node
);
}
Tensor
::
Tensor
(
Array
<
Expr
>
shape
,
FCompute
fcompute
,
std
::
string
name
)
{
auto
node
=
std
::
make_shared
<
TensorNode
>
();
node
->
name
=
std
::
move
(
name
);
node
->
shape
=
std
::
move
(
shape
);
size_t
ndim
=
node
->
shape
.
size
();
std
::
vector
<
Var
>
dim_index
;
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
std
::
ostringstream
os
;
os
<<
"dim_index"
<<
i
;
dim_index
.
push_back
(
Var
(
os
.
str
()));
}
node
->
dim_index
=
Array
<
Var
>
(
dim_index
);
node
->
source
=
fcompute
(
node
->
dim_index
);
node
->
dtype
=
node
->
source
.
dtype
();
node_
=
std
::
move
(
node
);
}
Expr
Tensor
::
operator
()(
Array
<
Expr
>
indices
)
const
{
CHECK_EQ
(
ndim
(),
indices
.
size
())
<<
"Tensor dimension mismatch in read"
<<
"ndim = "
<<
ndim
()
<<
", indices.size="
<<
indices
.
size
();
auto
node
=
std
::
make_shared
<
TensorReadNode
>
();
node
->
tensor
=
*
this
;
node
->
indices
=
std
::
move
(
indices
);
return
Expr
(
std
::
move
(
node
));
}
TVM_REGISTER_NODE_TYPE
(
TensorNode
);
}
// namespace tvm
tests/cpp/tensor_test.cc
View file @
56e10eb0
...
...
@@ -5,9 +5,14 @@
TEST
(
Tensor
,
Basic
)
{
using
namespace
tvm
;
Var
m
,
n
,
k
;
Tensor
A
({
m
,
k
});
Tensor
B
({
n
,
k
});
Var
m
(
"m"
),
n
(
"n"
),
l
(
"l"
);
Tensor
A
({
m
,
l
},
"A"
);
Tensor
B
({
n
,
l
},
"B"
);
RDomain
rd
({{
0
,
l
}});
auto
C
=
Tensor
({
m
,
n
},
[
&
](
Var
i
,
Var
j
)
{
return
sum
(
A
(
i
,
rd
.
i0
())
*
B
(
j
,
rd
.
i0
()),
rd
);
},
"C"
);
}
int
main
(
int
argc
,
char
**
argv
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment