Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
5f829774
Commit
5f829774
authored
Oct 19, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add domain
parent
5324b211
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
192 additions
and
9 deletions
+192
-9
include/tvm/array.h
+12
-0
include/tvm/domain.h
+121
-2
include/tvm/expr_node.h
+1
-1
include/tvm/expr_util.h
+3
-1
src/expr/domain.cc
+36
-0
src/expr/expr.cc
+7
-0
src/expr/expr_node.cc
+1
-0
tests/cpp/expr_test.cc
+10
-0
tests/cpp/tensor_test.cc
+1
-5
No files found.
include/tvm/array.h
View file @
5f829774
...
...
@@ -128,6 +128,18 @@ class Array : public NodeRef {
if
(
node_
.
get
()
==
nullptr
)
return
0
;
return
static_cast
<
const
ArrayNode
*>
(
node_
.
get
())
->
data
.
size
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Array
<
T
>&
r
)
{
// NOLINT(*)
for
(
size_t
i
=
0
;
i
<
r
.
size
();
++
i
)
{
if
(
i
==
0
)
{
os
<<
'['
;
}
else
{
os
<<
", "
;
}
os
<<
r
[
i
];
}
os
<<
']'
;
return
os
;
}
};
}
// namespace tvm
...
...
include/tvm/domain.h
View file @
5f829774
...
...
@@ -13,14 +13,133 @@
namespace
tvm
{
/*! \brief range over one dimension */
class
RangeNode
:
public
Node
{
public
:
/*! \brief beginning of the node */
Expr
begin
;
/*! \brief end of the node */
Expr
end
;
/*! \brief constructor */
RangeNode
()
{}
RangeNode
(
Expr
&&
begin
,
Expr
&&
end
)
:
begin
(
std
::
move
(
begin
)),
end
(
std
::
move
(
end
))
{
}
const
char
*
type_key
()
const
override
{
return
"RangeNode"
;
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"begin"
,
&
begin
);
fvisit
(
"end"
,
&
end
);
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{}
};
/*! \brief Node range */
class
Range
:
public
NodeRef
{
public
:
/*! \brief constructor */
Range
()
{}
/*!
* \brief constructor
* \param begin start of the range.
* \param end end of the range.
*/
Range
(
Expr
begin
,
Expr
end
);
/*! \return The extent of the range */
Expr
extent
()
const
;
/*! \return the begining of the range */
inline
const
Expr
&
begin
()
const
{
return
static_cast
<
const
RangeNode
*>
(
node_
.
get
())
->
begin
;
}
/*! \return the end of the range */
inline
const
Expr
&
end
()
const
{
return
static_cast
<
const
RangeNode
*>
(
node_
.
get
())
->
end
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Range
&
r
)
{
// NOLINT(*)
os
<<
'['
<<
r
.
begin
()
<<
", "
<<
r
.
end
()
<<
')'
;
return
os
;
}
};
//using Domain = Array<Range>;
/*! \brief Domain is a multi-dimensional range */
using
Domain
=
Array
<
Range
>
;
/*! \brief reduction domain node */
class
RDomainNode
:
public
Node
{
public
:
/*! \brief internal index */
Array
<
Var
>
index
;
/*! \brief The inernal domain */
Domain
domain
;
/*! \brief constructor */
RDomainNode
()
{}
RDomainNode
(
Array
<
Var
>
&&
index
,
Domain
&&
domain
)
:
index
(
std
::
move
(
index
)),
domain
(
std
::
move
(
domain
))
{
}
const
char
*
type_key
()
const
override
{
return
"RDomainNode"
;
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"index"
,
&
index
);
fvisit
(
"domain"
,
&
domain
);
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{}
};
/*! \brief reduction domain */
class
RDomain
:
public
NodeRef
{
public
:
/*! \brief constructor*/
RDomain
()
{}
/*!
* constructor by domain
* \param domain The domain of reduction.
*/
explicit
RDomain
(
Domain
domain
);
/*!
* \brief constructor by list of ranges
* \param domain The reduction domain
*/
explicit
RDomain
(
std
::
initializer_list
<
Range
>
domain
)
:
RDomain
(
Domain
(
domain
))
{}
/*!
* \brief constructor from node pointer
* \param nptr Another node shared pointer
*/
explicit
RDomain
(
std
::
shared_ptr
<
Node
>&&
nptr
)
:
NodeRef
(
std
::
move
(
nptr
))
{
CHECK
(
node_
.
get
()
!=
nullptr
);
CHECK
(
node_
->
is_type
<
RDomainNode
>
());
}
/*! \return The dimension of the RDomain */
inline
size_t
ndim
()
const
{
return
static_cast
<
const
RDomainNode
*>
(
node_
.
get
())
->
index
.
size
();
}
/*! \return the 0-th index of the domain */
inline
Var
i0
()
const
{
return
index
(
0
);
}
/*!
* \param i the index.
* \return i-th index variable in the RDomain
*/
inline
Var
index
(
size_t
i
)
const
{
return
static_cast
<
const
RDomainNode
*>
(
node_
.
get
())
->
index
[
i
];
}
/*!
* \return The domain of the reduction.
*/
inline
const
Domain
&
domain
()
const
{
return
static_cast
<
const
RDomainNode
*>
(
node_
.
get
())
->
domain
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
RDomain
&
r
)
{
// NOLINT(*)
os
<<
"rdomain("
<<
r
.
domain
()
<<
")"
;
return
os
;
}
};
/*! \brief use RDom as alias of RDomain */
using
RDom
=
RDomain
;
}
// namespace tvm
...
...
include/tvm/expr_node.h
View file @
5f829774
...
...
@@ -11,8 +11,8 @@
#include "./tensor.h"
#include "./expr.h"
namespace
tvm
{
/*! \brief variable node for symbolic variables */
class
VarNode
:
public
ExprNode
{
public
:
...
...
include/tvm/expr_util.h
View file @
5f829774
...
...
@@ -16,7 +16,9 @@ namespace tvm {
* \param src The source expression
* \return the simplified expression.
*/
Expr
Simplify
(
const
Expr
&
src
);
inline
Expr
Simplify
(
Expr
src
)
{
return
src
;
}
/*!
* \brief visit the exression node in expr tree in post DFS order.
...
...
src/expr/domain.cc
0 → 100644
View file @
5f829774
/*!
* Copyright (c) 2016 by Contributors
* \file domain.cc
*/
#include <tvm/domain.h>
#include <tvm/op.h>
#include <tvm/expr_node.h>
#include <tvm/expr_util.h>
namespace
tvm
{
Range
::
Range
(
Expr
begin
,
Expr
end
)
{
node_
=
std
::
make_shared
<
RangeNode
>
(
std
::
move
(
begin
),
std
::
move
(
end
));
}
Expr
Range
::
extent
()
const
{
return
Simplify
(
end
()
-
begin
());
}
RDomain
::
RDomain
(
Domain
domain
)
{
std
::
vector
<
Var
>
index
;
for
(
size_t
i
=
0
;
i
<
domain
.
size
();
++
i
)
{
index
.
push_back
(
Var
(
"reduction_index"
));
}
Array
<
Var
>
idx
(
index
);
node_
=
std
::
make_shared
<
RDomainNode
>
(
std
::
move
(
idx
),
std
::
move
(
domain
));
}
TVM_REGISTER_NODE_TYPE
(
RangeNode
);
TVM_REGISTER_NODE_TYPE
(
ArrayNode
);
TVM_REGISTER_NODE_TYPE
(
RDomainNode
);
}
// namespace tvm
src/expr/expr.cc
View file @
5f829774
...
...
@@ -48,6 +48,13 @@ void Expr::Print(std::ostream& os) const {
os
<<
')'
;
return
;
}
case
kReduceNode
:
{
const
auto
*
n
=
Get
<
ReduceNode
>
();
os
<<
"reduce("
<<
n
->
op
->
FunctionName
()
<<
", "
;
n
->
src
.
Print
(
os
);
os
<<
", "
<<
n
->
rdom
<<
')'
;
return
;
}
default
:
{
LOG
(
FATAL
)
<<
"not able to handle type "
<<
typeid
(
node_
.
get
()).
name
();
}
...
...
src/expr/expr_node.cc
View file @
5f829774
...
...
@@ -42,5 +42,6 @@ TVM_REGISTER_NODE_TYPE(IntNode);
TVM_REGISTER_NODE_TYPE
(
FloatNode
);
TVM_REGISTER_NODE_TYPE
(
UnaryOpNode
);
TVM_REGISTER_NODE_TYPE
(
BinaryOpNode
);
TVM_REGISTER_NODE_TYPE
(
ReduceNode
);
}
// namespace tvm
tests/cpp/expr_test.cc
View file @
5f829774
...
...
@@ -11,6 +11,16 @@ TEST(Expr, Basic) {
CHECK
(
os
.
str
()
==
"max(((x + 1) + 2), 100)"
);
}
TEST
(
Expr
,
Reduction
)
{
using
namespace
tvm
;
Var
x
(
"x"
);
RDomain
rdom
({{
0
,
3
}});
auto
z
=
sum
(
x
+
1
+
2
,
rdom
);
std
::
ostringstream
os
;
os
<<
z
;
CHECK
(
os
.
str
()
==
"reduce(+, ((x + 1) + 2), rdomain([[0, 3)]))"
);
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
...
...
tests/cpp/tensor_test.cc
View file @
5f829774
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
...
...
@@ -7,11 +8,6 @@ TEST(Tensor, Basic) {
Var
m
,
n
,
k
;
Tensor
A
({
m
,
k
});
Tensor
B
({
n
,
k
});
auto
x
=
[
=
](
Var
i
,
Var
j
,
Var
k
)
{
return
A
(
i
,
k
)
*
B
(
j
,
k
);
};
auto
C
=
Tensor
({
m
,
n
},
x
);
}
int
main
(
int
argc
,
char
**
argv
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment