Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
70d93028
Commit
70d93028
authored
Nov 29, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Keep up with changes of NodeRef
parent
2fc12dcd
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
153 additions
and
46 deletions
+153
-46
HalideIR
+1
-1
include/tvm/domain.h
+3
-3
include/tvm/ir_mutator.h
+4
-4
include/tvm/ir_pass.h
+1
-1
include/tvm/ir_visitor.h
+4
-4
include/tvm/operation.h
+3
-3
include/tvm/schedule.h
+4
-0
include/tvm/split.h
+4
-3
include/tvm/tensor.h
+4
-3
include/tvm/tvm.h
+1
-0
src/c_api/c_api_function.cc
+2
-2
src/pass/ir_mutator.cc
+1
-1
src/pass/ir_visitor.cc
+5
-5
src/pass/schedule_ops.cc
+25
-9
src/pass/scope.h
+84
-0
src/pass/ssa.cc
+4
-4
tests/cpp/ir_functor_test.cc
+2
-2
tests/cpp/ir_visitor_test.cc
+1
-1
No files found.
HalideIR
@
eb2f7d60
Subproject commit
bf96f8af0dfd1f79d258c7c1506f9ded932b94a9
Subproject commit
eb2f7d604a611318fc685172847bcf5ba2fcf835
include/tvm/domain.h
View file @
70d93028
...
...
@@ -95,13 +95,13 @@ class RDomainNode : public Node {
RDomainNode
(
Array
<
Var
>
index
,
Domain
domain
)
:
index
(
index
),
domain
(
domain
)
{
}
const
char
*
type_key
()
const
override
{
return
"RDomain"
;
}
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"index"
,
&
index
);
v
->
Visit
(
"domain"
,
&
domain
);
}
static
constexpr
const
char
*
_type_key
=
"RDomain"
;
TVM_DECLARE_NODE_TYPE_INFO
(
RDomainNode
);
};
inline
const
RDomainNode
*
RDomain
::
operator
->
()
const
{
...
...
include/tvm/ir_mutator.h
View file @
70d93028
...
...
@@ -6,7 +6,7 @@
#ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_
#include <tvm/ir_
node
.h>
#include <tvm/ir_
functor
.h>
#include <unordered_map>
#include "./expr.h"
...
...
@@ -16,7 +16,7 @@ namespace ir {
* \brief a base class for mutator to iterative mutate the IR
*
* This IRMutator is implemented via IRFunctor instead of Visitor Pattern.
* This enables easy extensions of possible new
IR
Node.
* This enables easy extensions of possible new Node.
* It also makes changing return types easier.
*
* \note If you want to return a different type other than Expr and Stmt,
...
...
@@ -44,9 +44,9 @@ class IRMutator {
/*! \brief destructor */
virtual
~
IRMutator
()
{}
/*! \brief functor type of expr mutation */
using
FMutateExpr
=
IRFunctor
<
Expr
(
const
IR
NodeRef
&
,
const
Expr
&
,
IRMutator
*
)
>
;
using
FMutateExpr
=
IRFunctor
<
Expr
(
const
NodeRef
&
,
const
Expr
&
,
IRMutator
*
)
>
;
/*! \brief functor type of stmt mutation */
using
FMutateStmt
=
IRFunctor
<
Stmt
(
const
IR
NodeRef
&
,
const
Stmt
&
,
IRMutator
*
)
>
;
using
FMutateStmt
=
IRFunctor
<
Stmt
(
const
NodeRef
&
,
const
Stmt
&
,
IRMutator
*
)
>
;
/*! \return internal vtable of expr */
static
FMutateExpr
&
vtable_expr
();
// NOLINT(*)
/*! \return internal stmt of expr */
...
...
include/tvm/ir_pass.h
View file @
70d93028
...
...
@@ -9,7 +9,7 @@
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <tvm/ir_
node
.h>
#include <tvm/ir_
functor
.h>
#include <unordered_map>
#include <vector>
#include "./expr.h"
...
...
include/tvm/ir_visitor.h
View file @
70d93028
...
...
@@ -15,7 +15,7 @@ namespace ir {
* \brief a base class for visitor to iterative traverse the IR
*
* This IRVisitor is implemented via IRFunctor
* This enables extensions of possible new
IR
Node.
* This enables extensions of possible new Node.
*
* \sa IRFunctor, PostOrderVisit
*/
...
...
@@ -24,14 +24,14 @@ class IRVisitor {
/*!
* \brief recursively visit an IR node
*/
virtual
void
Visit
(
const
IR
NodeRef
&
node
)
{
virtual
void
Visit
(
const
NodeRef
&
node
)
{
static
const
FVisit
&
f
=
vtable
();
if
(
node
.
defined
())
f
(
node
,
this
);
}
/*! \brief destructor */
virtual
~
IRVisitor
()
{}
/*! \brief functor type of visitor */
using
FVisit
=
IRFunctor
<
void
(
const
IR
NodeRef
&
,
IRVisitor
*
)
>
;
using
FVisit
=
IRFunctor
<
void
(
const
NodeRef
&
,
IRVisitor
*
)
>
;
/*! \return internal vtable*/
static
FVisit
&
vtable
();
};
...
...
@@ -42,7 +42,7 @@ class IRVisitor {
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
void
PostOrderVisit
(
const
IRNodeRef
&
node
,
std
::
function
<
void
(
const
IR
NodeRef
&
)
>
fvisit
);
void
PostOrderVisit
(
const
NodeRef
&
node
,
std
::
function
<
void
(
const
NodeRef
&
)
>
fvisit
);
}
// namespace ir
}
// namespace tvm
...
...
include/tvm/operation.h
View file @
70d93028
...
...
@@ -23,9 +23,6 @@ class ComputeOpNode : public OperationNode {
/*! \brief constructor */
ComputeOpNode
()
{}
const
char
*
type_key
()
const
final
{
return
"ComputeOp"
;
}
size_t
num_outputs
()
const
final
{
return
1
;
}
...
...
@@ -43,6 +40,9 @@ class ComputeOpNode : public OperationNode {
std
::
string
name
,
Array
<
Var
>
dim_var
,
Expr
body
);
static
constexpr
const
char
*
_type_key
=
"ComputeOp"
;
TVM_DECLARE_NODE_TYPE_INFO
(
ComputeOpNode
);
};
...
...
include/tvm/schedule.h
View file @
70d93028
...
...
@@ -62,6 +62,10 @@ class ScheduleNode : public Node {
const
char
*
type_key
()
const
final
{
return
"Schedule"
;
}
const
uint32_t
type_index
()
const
final
{
static
uint32_t
tidx
=
TypeKey2Index
(
type_key
());
return
tidx
;
}
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"scope"
,
&
scope
);
v
->
Visit
(
"op"
,
&
op
);
...
...
include/tvm/split.h
View file @
70d93028
...
...
@@ -46,14 +46,15 @@ class DimSplitNode : public SplitNode {
Expr
factor
;
/*! \brief constructor */
DimSplitNode
()
{}
const
char
*
type_key
()
const
final
{
return
"DimSplit"
;
}
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"var"
,
&
var
);
v
->
Visit
(
"factor"
,
&
factor
);
}
static
Split
make
(
Var
var
,
Expr
factor
);
static
constexpr
const
char
*
_type_key
=
"DimSplit"
;
TVM_DECLARE_NODE_TYPE_INFO
(
DimSplitNode
);
};
// Implementations of inline functions
...
...
include/tvm/tensor.h
View file @
70d93028
...
...
@@ -104,9 +104,7 @@ class TensorNode : public FunctionBaseNode {
int
value_index
{
0
};
/*! \brief constructor */
TensorNode
()
{}
const
char
*
type_key
()
const
final
{
return
"Tensor"
;
}
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"shape"
,
&
shape
);
v
->
Visit
(
"name"
,
&
name
);
...
...
@@ -125,6 +123,9 @@ class TensorNode : public FunctionBaseNode {
Type
dtype
,
Operation
op
,
int
value_index
);
static
constexpr
const
char
*
_type_key
=
"Tensor"
;
TVM_DECLARE_NODE_TYPE_INFO
(
TensorNode
);
};
/*!
...
...
include/tvm/tvm.h
View file @
70d93028
...
...
@@ -9,5 +9,6 @@
#include "./base.h"
#include "./expr.h"
#include "./tensor.h"
#include "./operation.h"
#endif // TVM_TVM_H_
src/c_api/c_api_function.cc
View file @
70d93028
...
...
@@ -26,9 +26,9 @@ TVM_REGISTER_API(_format_str)
CHECK
(
args
.
at
(
0
).
type_id
==
kNodeHandle
);
std
::
ostringstream
os
;
auto
&
sptr
=
args
.
at
(
0
).
sptr
;
if
(
sptr
->
is_type
<
TensorNode
>
(
))
{
if
(
dynamic_cast
<
const
TensorNode
*>
(
sptr
.
get
()
))
{
os
<<
args
.
at
(
0
).
operator
Tensor
();
}
else
if
(
sptr
->
is_type
<
RDomainNode
>
(
))
{
}
else
if
(
dynamic_cast
<
const
RDomainNode
*>
(
sptr
.
get
()
))
{
os
<<
args
.
at
(
0
).
operator
RDomain
();
}
else
if
(
dynamic_cast
<
const
BaseExprNode
*>
(
sptr
.
get
()))
{
os
<<
args
.
at
(
0
).
operator
Expr
();
...
...
src/pass/ir_mutator.cc
View file @
70d93028
...
...
@@ -22,7 +22,7 @@ namespace {
using
namespace
Halide
::
Internal
;
// const expr
inline
Expr
ReturnSelfExpr
(
const
IR
NodeRef
&
,
const
Expr
&
e
,
IRMutator
*
)
{
inline
Expr
ReturnSelfExpr
(
const
NodeRef
&
,
const
Expr
&
e
,
IRMutator
*
)
{
return
e
;
}
...
...
src/pass/ir_visitor.cc
View file @
70d93028
...
...
@@ -12,9 +12,9 @@ namespace {
// visitor to implement apply
class
IRApplyVisit
:
public
IRVisitor
{
public
:
explicit
IRApplyVisit
(
std
::
function
<
void
(
const
IR
NodeRef
&
)
>
f
)
:
f_
(
f
)
{}
explicit
IRApplyVisit
(
std
::
function
<
void
(
const
NodeRef
&
)
>
f
)
:
f_
(
f
)
{}
void
Visit
(
const
IR
NodeRef
&
node
)
final
{
void
Visit
(
const
NodeRef
&
node
)
final
{
if
(
visited_
.
count
(
node
.
get
())
!=
0
)
return
;
visited_
.
insert
(
node
.
get
());
IRVisitor
::
Visit
(
node
);
...
...
@@ -22,13 +22,13 @@ class IRApplyVisit : public IRVisitor {
}
private
:
std
::
function
<
void
(
const
IR
NodeRef
&
)
>
f_
;
std
::
function
<
void
(
const
NodeRef
&
)
>
f_
;
std
::
unordered_set
<
const
Node
*>
visited_
;
};
}
// namespace
void
PostOrderVisit
(
const
IRNodeRef
&
node
,
std
::
function
<
void
(
const
IR
NodeRef
&
)
>
fvisit
)
{
void
PostOrderVisit
(
const
NodeRef
&
node
,
std
::
function
<
void
(
const
NodeRef
&
)
>
fvisit
)
{
IRApplyVisit
(
fvisit
).
Visit
(
node
);
}
...
...
@@ -42,7 +42,7 @@ namespace {
using
namespace
Halide
::
Internal
;
void
NoOp
(
const
IR
NodeRef
&
n
,
IRVisitor
*
v
)
{
void
NoOp
(
const
NodeRef
&
n
,
IRVisitor
*
v
)
{
}
inline
void
VisitArray
(
Array
<
Expr
>
arr
,
IRVisitor
*
v
)
{
...
...
src/pass/schedule_ops.cc
View file @
70d93028
...
...
@@ -5,21 +5,37 @@
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include "./scope.h"
namespace
tvm
{
namespace
ir
{
namespace
{
Stmt
MakeCompute
(
const
ComputeOpNode
*
op
,
const
Array
<
Split
>&
splits
)
{
Tensor
output
;
std
::
vector
<
Expr
>
args
(
op
->
dim_var
.
size
());
for
(
size_t
i
=
0
;
i
<
args
.
size
();
++
i
)
{
args
[
i
]
=
op
->
dim_var
[
i
];
/*!
* \brief make nest loops given list of stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body The inner-most body of the loop
*/
Stmt
MakeLoop
(
std
::
vector
<
Stmt
>&&
nest
,
Stmt
body
)
{
while
(
!
nest
.
empty
())
{
Stmt
s
=
std
::
move
(
nest
.
back
());
nest
.
pop_back
();
if
(
s
.
as
<
For
>
())
{
auto
n
=
std
::
make_shared
<
For
>
(
*
s
.
as
<
For
>
());
n
->
body
=
body
;
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
LetStmt
>
())
{
auto
n
=
std
::
make_shared
<
LetStmt
>
(
*
s
.
as
<
LetStmt
>
());
n
->
body
=
body
;
body
=
Stmt
(
n
);
}
else
if
(
s
.
as
<
AttrStmt
>
())
{
auto
n
=
std
::
make_shared
<
AttrStmt
>
(
*
s
.
as
<
AttrStmt
>
());
n
->
body
=
body
;
body
=
Stmt
(
n
);
}
else
{
LOG
(
FATAL
)
<<
"not supported nest type"
;
}
}
Array
<
Expr
>
values
{
op
->
body
};
Stmt
stmt
=
Provide
::
make
(
output
,
values
,
args
);
// add splits from ousside most to outsidemost to innermost
return
stmt
;
return
body
;
}
...
...
src/pass/scope.h
0 → 100644
View file @
70d93028
/*!
* Copyright (c) 2016 by Contributors
* \file scope.h
* \brief attribute scope data structure,
* defines attributes on current domain
*/
#ifndef TVM_PASS_SCOPE_H_
#define TVM_PASS_SCOPE_H_
#include <tvm/ir.h>
#include <unordered_map>
#include <vector>
#include <string>
namespace
tvm
{
namespace
ir
{
/*!
* \brief Attribute scope of Nodes in the IR.
* \tparam ValueType The value of of the scope.
*/
template
<
typename
K
,
typename
V
>
class
Scope
{
public
:
/*!
* \brief Push value to scope
* \param key the key to be pushed.
* \param v The value to be pushed.
*/
inline
void
Push
(
const
K
&
key
,
V
v
)
{
data_
[
key
].
emplace_back
(
v
);
}
/*!
* \brief Pop value from scope.
* \param key the key to be poped
*/
inline
void
Pop
(
const
K
&
key
)
{
auto
&
v
=
data_
[
key
];
CHECK_NE
(
v
.
size
(),
0
);
v
.
pop_back
();
}
/*!
* \brief Get value from the scope
* \param key the key to fetch.
* \return The value to be fetched.
*/
inline
V
operator
[](
const
K
&
key
)
const
{
const
auto
it
=
data_
.
find
(
key
);
CHECK
(
it
!=
data_
.
end
()
&&
it
->
second
.
size
()
!=
0
)
<<
"cannot find value in scope"
;
return
it
->
second
.
back
();
}
private
:
std
::
unordered_map
<
K
,
std
::
vector
<
V
>
>
data_
;
};
/*! \brief Attribute key for specific attribute */
struct
AttrKey
{
/*! \brief The node of the attribute */
NodeRef
node
;
/*! \brief The type key of the attribute. */
std
::
string
type_key
;
// overload operator ==
inline
bool
operator
==
(
const
AttrKey
&
other
)
const
{
return
node
==
other
.
node
&&
type_key
==
other
.
type_key
;
}
};
}
// namespace ir
}
// namespace tvm
namespace
std
{
template
<>
struct
hash
<::
tvm
::
ir
::
AttrKey
>
{
std
::
size_t
operator
()(
const
::
tvm
::
ir
::
AttrKey
&
k
)
const
{
size_t
lhs
=
k
.
node
.
hash
();
size_t
rhs
=
std
::
hash
<
std
::
string
>
()(
k
.
type_key
);
lhs
^=
rhs
+
0x9e3779b9
+
(
lhs
<<
6
)
+
(
lhs
>>
2
);
return
lhs
;
}
};
}
// namespace std
#endif // TVM_PASS_SCOPE_H_
src/pass/ssa.cc
View file @
70d93028
...
...
@@ -17,7 +17,7 @@ namespace {
// global functor to get var definition from
struct
FGetVarDef
{
using
FType
=
IRFunctor
<
VarExpr
(
const
IR
NodeRef
&
)
>
;
using
FType
=
IRFunctor
<
VarExpr
(
const
NodeRef
&
)
>
;
static
FType
&
vtable
()
{
// NOLINT(*)
static
FType
inst
;
return
inst
;
}
...
...
@@ -37,8 +37,8 @@ TVM_STATIC_IR_FUNCTOR(FGetVarDef, vtable)
});
struct
FSetVarDef
{
using
FTypeExpr
=
IRFunctor
<
Expr
(
const
IR
NodeRef
&
,
VarExpr
)
>
;
using
FTypeStmt
=
IRFunctor
<
Stmt
(
const
IR
NodeRef
&
,
VarExpr
)
>
;
using
FTypeExpr
=
IRFunctor
<
Expr
(
const
NodeRef
&
,
VarExpr
)
>
;
using
FTypeStmt
=
IRFunctor
<
Stmt
(
const
NodeRef
&
,
VarExpr
)
>
;
static
FTypeExpr
&
vtable_expr
()
{
// NOLINT(*)
static
FTypeExpr
inst
;
return
inst
;
}
...
...
@@ -69,7 +69,7 @@ class IRVerifySSA : public IRVisitor {
public
:
bool
is_ssa
{
true
};
void
Visit
(
const
IR
NodeRef
&
n
)
final
{
void
Visit
(
const
NodeRef
&
n
)
final
{
if
(
!
is_ssa
)
return
;
static
auto
&
fget_var_def
=
FGetVarDef
::
vtable
();
if
(
fget_var_def
.
can_dispatch
(
n
))
{
...
...
tests/cpp/ir_functor_test.cc
View file @
70d93028
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_
node
.h>
#include <tvm/ir_
functor
.h>
TEST
(
IRF
,
Basic
)
{
using
namespace
Halide
::
Internal
;
...
...
@@ -9,7 +9,7 @@ TEST(IRF, Basic) {
Var
x
(
"x"
);
auto
z
=
x
+
1
;
IRFunctor
<
int
(
const
IR
NodeRef
&
n
,
int
b
)
>
f
;
IRFunctor
<
int
(
const
NodeRef
&
n
,
int
b
)
>
f
;
LOG
(
INFO
)
<<
"x"
;
f
.
set_dispatch
<
Variable
>
([](
const
Variable
*
n
,
int
b
)
{
return
b
;
...
...
tests/cpp/ir_visitor_test.cc
View file @
70d93028
...
...
@@ -11,7 +11,7 @@ TEST(IRVisitor, CountVar) {
Var
x
(
"x"
),
y
;
auto
z
=
x
+
1
+
y
+
y
;
ir
::
PostOrderVisit
(
z
,
[
&
n_var
](
const
IR
NodeRef
&
n
)
{
ir
::
PostOrderVisit
(
z
,
[
&
n_var
](
const
NodeRef
&
n
)
{
if
(
n
.
as
<
Variable
>
())
++
n_var
;
});
CHECK_EQ
(
n_var
,
2
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment