Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
28876530
Commit
28876530
authored
Nov 28, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add AttrStmt
parent
61de73b4
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
122 additions
and
83 deletions
+122
-83
HalideIR
+1
-1
include/tvm/ir.h
+29
-0
include/tvm/operation.h
+1
-0
include/tvm/schedule.h
+14
-42
src/c_api/c_api_ir.cc
+10
-23
src/lang/ir.cc
+25
-6
src/lang/operation.cc
+0
-2
src/lang/schedule.cc
+0
-1
src/pass/inline.cc
+2
-1
src/pass/ir_mutator.cc
+12
-0
src/pass/ir_visitor.cc
+6
-0
src/pass/schedule_ops.cc
+12
-4
tests/python/test_basic.py
+10
-3
No files found.
HalideIR
@
bf96f8af
Subproject commit
7f1d811972bccc26f651ea2289d88bcadea9fe9f
Subproject commit
bf96f8af0dfd1f79d258c7c1506f9ded932b94a9
include/tvm/ir.h
View file @
28876530
...
...
@@ -17,6 +17,7 @@ namespace tvm {
namespace
ir
{
using
Halide
::
Internal
::
ExprNode
;
using
Halide
::
Internal
::
StmtNode
;
using
Halide
::
Internal
::
IRNodeType
;
using
Halide
::
Internal
::
ForType
;
...
...
@@ -47,6 +48,34 @@ struct Reduce : public ExprNode<Reduce> {
static
constexpr
const
char
*
Min
=
"Min"
;
};
/*!
* \brief Define certain auxiliary attribute for the body to be a symbolic value.
* This is used to insert hint(shape, storage, split) about certain scopes.
*/
struct
AttrStmt
:
public
StmtNode
<
AttrStmt
>
{
/*! \brief this is attribute about certain node */
NodeRef
node
;
/*! \brief the type key of the attribute */
std
::
string
type_key
;
/*! \brief The attribute value, value is well defined at current scope. */
Expr
value
;
/*! \brief The body statement to be executed */
Stmt
body
;
/*! \brief construct expr from name and rdom */
static
Stmt
make
(
NodeRef
node
,
std
::
string
type_key
,
Expr
value
,
Stmt
body
);
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"node"
,
&
node
);
v
->
Visit
(
"type_key"
,
&
type_key
);
v
->
Visit
(
"value"
,
&
value
);
v
->
Visit
(
"body"
,
&
body
);
}
static
const
IRNodeType
_type_info
=
IRNodeType
::
ExtensionExpr
;
static
constexpr
const
char
*
_type_key
=
"AttrStmt"
;
};
// Reuse IR node defintiion from HalideIR
using
Halide
::
Internal
::
IntImm
;
using
Halide
::
Internal
::
UIntImm
;
...
...
include/tvm/operation.h
View file @
28876530
...
...
@@ -32,6 +32,7 @@ class ComputeOpNode : public OperationNode {
std
::
string
output_name
(
size_t
i
)
const
final
;
Type
output_dtype
(
size_t
i
)
const
final
;
Array
<
Expr
>
output_shape
(
size_t
i
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"domain"
,
&
domain
);
v
->
Visit
(
"name"
,
&
name
);
...
...
include/tvm/schedule.h
View file @
28876530
...
...
@@ -38,42 +38,7 @@ class Schedule : public NodeRef {
inline
const
ScheduleNode
*
operator
->
()
const
;
};
/*! \brief schedule container */
class
AttachSpec
:
public
NodeRef
{
public
:
AttachSpec
()
{}
explicit
AttachSpec
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline
const
AttachSpecNode
*
operator
->
()
const
;
};
// defintion of node containers
/*! \brief The attach specification of each subschedule */
class
AttachSpecNode
:
public
Node
{
public
:
/*! \brief The attachment type */
AttachType
attach_type
;
/*!
* \brief The split to be attached to,
* only valid when attach_type is kRoot
*/
Split
attach_split
;
/*! \brief the child schedule to be attached. */
Schedule
schedule
;
const
char
*
type_key
()
const
final
{
return
"AttachSpec"
;
}
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"attach_type"
,
&
attach_type
);
v
->
Visit
(
"attach_split"
,
&
attach_split
);
v
->
Visit
(
"schedule"
,
&
schedule
);
}
};
/*! \brief represents the schedule of the tensor */
class
ScheduleNode
:
public
Node
{
public
:
...
...
@@ -83,8 +48,17 @@ class ScheduleNode : public Node {
std
::
string
scope
;
/*! \brief Splits over iteration domains */
Array
<
Split
>
splits
;
/*! \brief attach specifications */
Array
<
AttachSpec
>
attachs
;
/*! \brief The attachment type of the schedule */
AttachType
attach_type
;
/*!
* \brief The attach point of this schedule, if it is a split
* \note This is not a cyclic dependency,
* because split do not refer back to parent schedule.
*/
Split
attach_parent
;
/*! \brief the schedules that this schedule depend on */
Array
<
Schedule
>
children
;
// the type key
const
char
*
type_key
()
const
final
{
return
"Schedule"
;
}
...
...
@@ -92,7 +66,9 @@ class ScheduleNode : public Node {
v
->
Visit
(
"scope"
,
&
scope
);
v
->
Visit
(
"op"
,
&
op
);
v
->
Visit
(
"splits"
,
&
splits
);
v
->
Visit
(
"attachs"
,
&
attachs
);
v
->
Visit
(
"attach_type"
,
&
attach_type
);
v
->
Visit
(
"attach_parent"
,
&
attach_parent
);
v
->
Visit
(
"children"
,
&
children
);
}
};
...
...
@@ -101,9 +77,5 @@ inline const ScheduleNode* Schedule::operator->() const {
return
static_cast
<
const
ScheduleNode
*>
(
node_
.
get
());
}
inline
const
AttachSpecNode
*
AttachSpec
::
operator
->
()
const
{
return
static_cast
<
const
AttachSpecNode
*>
(
node_
.
get
());
}
}
// namespace tvm
#endif // TVM_SCHEDULE_H_
src/c_api/c_api_ir.cc
View file @
28876530
...
...
@@ -29,13 +29,6 @@ TVM_REGISTER_API(_make_For)
args
.
at
(
5
));
});
TVM_REGISTER_API
(
_make_Reduce
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
Reduce
::
make
(
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
));
});
TVM_REGISTER_API
(
_make_Call
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
Call
::
make
(
args
.
at
(
0
),
...
...
@@ -54,22 +47,6 @@ TVM_REGISTER_API(_make_Allocate)
args
.
at
(
4
));
});
TVM_REGISTER_API
(
_make_LetStmt
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
if
(
args
.
size
()
==
3
)
{
*
ret
=
LetStmt
::
make
(
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
));
}
else
{
CHECK_EQ
(
args
.
size
(),
5
);
*
ret
=
LetStmt
::
make
(
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
),
args
.
at
(
3
),
args
.
at
(
4
));
}
});
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \
...
...
@@ -89,6 +66,12 @@ TVM_REGISTER_API(_make_LetStmt)
*ret = Node::make(args.at(0), args.at(1), args.at(2)); \
}) \
#define REGISTER_MAKE4(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = Node::make(args.at(0), args.at(1), args.at(2), args.at(3)); \
}) \
#define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
...
...
@@ -99,6 +82,9 @@ TVM_REGISTER_API(_make_LetStmt)
.add_argument("lhs", "Expr", "left operand") \
.add_argument("rhs", "Expr", "right operand")
REGISTER_MAKE3
(
Reduce
);
REGISTER_MAKE4
(
AttrStmt
);
REGISTER_MAKE2
(
IntImm
);
REGISTER_MAKE2
(
UIntImm
);
REGISTER_MAKE2
(
FloatImm
);
...
...
@@ -123,6 +109,7 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3
(
Ramp
);
REGISTER_MAKE2
(
Broadcast
);
REGISTER_MAKE3
(
Let
);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE2
(
AssertStmt
);
REGISTER_MAKE3
(
ProducerConsumer
);
REGISTER_MAKE3
(
Store
);
...
...
src/lang/ir.cc
View file @
28876530
...
...
@@ -18,10 +18,16 @@ namespace Halide {
namespace
Internal
{
using
tvm
::
ir
::
Reduce
;
using
tvm
::
ir
::
AttrStmt
;
template
<>
void
ExprNode
<
Reduce
>::
accept
(
IRVisitor
*
v
,
const
Expr
&
)
const
{
LOG
(
FATAL
)
<<
"Reduce do not work with IRVisitor yet"
;
LOG
(
FATAL
)
<<
"Reduce do not work with old Visitor, use IRFunctor style visitor"
;
}
template
<>
void
StmtNode
<
AttrStmt
>::
accept
(
IRVisitor
*
v
,
const
Stmt
&
)
const
{
LOG
(
FATAL
)
<<
"AttrStmt do not work with old Visitor, use IRFunctor style visitor"
;
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
...
...
@@ -33,15 +39,20 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p
->
stream
<<
", rdom="
<<
op
->
rdom
<<
")"
;
});
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
AttrStmt
>
([](
const
AttrStmt
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"attr "
<<
op
->
type_key
<<
" = "
;
p
->
print
(
op
->
value
);
p
->
stream
<<
'\n'
;
p
->
print
(
op
->
body
);
});
}
// namespace Internal
}
// namespace Halide
namespace
tvm
{
namespace
ir
{
// reduce
TVM_REGISTER_NODE_TYPE
(
Reduce
);
Expr
Reduce
::
make
(
std
::
string
op
,
Expr
source
,
RDomain
rdom
)
{
auto
n
=
std
::
make_shared
<
Reduce
>
();
CHECK
(
source
.
defined
());
...
...
@@ -52,9 +63,17 @@ Expr Reduce::make(std::string op, Expr source, RDomain rdom) {
return
Expr
(
n
);
}
Stmt
AttrStmt
::
make
(
NodeRef
node
,
std
::
string
type_key
,
Expr
value
,
Stmt
body
)
{
auto
n
=
std
::
make_shared
<
AttrStmt
>
();
n
->
node
=
node
;
n
->
type_key
=
type_key
;
n
->
value
=
value
;
n
->
body
=
body
;
return
Stmt
(
n
);
}
// HalideIR node
using
namespace
Halide
::
Internal
;
TVM_REGISTER_NODE_TYPE
(
Reduce
);
TVM_REGISTER_NODE_TYPE
(
AttrStmt
)
;
TVM_REGISTER_NODE_TYPE
(
FloatImm
);
TVM_REGISTER_NODE_TYPE
(
IntImm
);
...
...
src/lang/operation.cc
View file @
28876530
...
...
@@ -74,8 +74,6 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const {
return
Array
<
Expr
>
(
shape
);
}
TVM_REGISTER_NODE_TYPE
(
ComputeOpNode
);
}
// namespace tvm
src/lang/schedule.cc
View file @
28876530
...
...
@@ -13,7 +13,6 @@ Schedule::Schedule(Operation op, std::string scope) {
node_
=
n
;
}
TVM_REGISTER_NODE_TYPE
(
AttachSpecNode
);
TVM_REGISTER_NODE_TYPE
(
ScheduleNode
);
}
// namespace tvm
src/pass/inline.cc
View file @
28876530
...
...
@@ -19,11 +19,12 @@ class IRInline : public IRMutator {
:
f_
(
f
),
args_
(
args
),
body_
(
body
)
{}
Expr
Mutate
(
Expr
expr
)
final
{
expr
=
IRMutator
::
Mutate
(
expr
);
const
Call
*
call
=
expr
.
as
<
Call
>
();
if
(
call
!=
nullptr
&&
call
->
func
==
f_
)
{
return
InlineCall
(
call
);
}
else
{
return
IRMutator
::
Mutate
(
expr
)
;
return
expr
;
}
}
...
...
src/pass/ir_mutator.cc
View file @
28876530
...
...
@@ -72,6 +72,18 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
});
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_stmt
)
.
set_dispatch
<
AttrStmt
>
([](
const
AttrStmt
*
op
,
const
Stmt
&
s
,
IRMutator
*
m
)
{
Expr
value
=
m
->
Mutate
(
op
->
value
);
Stmt
body
=
m
->
Mutate
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
return
AttrStmt
::
make
(
op
->
node
,
op
->
type_key
,
op
->
value
,
op
->
body
);
}
});
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
IntImm
>
(
ReturnSelfExpr
)
.
set_dispatch
<
UIntImm
>
(
ReturnSelfExpr
)
...
...
src/pass/ir_visitor.cc
View file @
28876530
...
...
@@ -66,6 +66,12 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
});
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
.
set_dispatch
<
AttrStmt
>
([](
const
AttrStmt
*
op
,
IRVisitor
*
v
)
{
v
->
Visit
(
op
->
value
);
v
->
Visit
(
op
->
body
);
});
TVM_STATIC_IR_FUNCTOR
(
IRVisitor
,
vtable
)
.
set_dispatch
<
IntImm
>
(
NoOp
)
.
set_dispatch
<
UIntImm
>
(
NoOp
)
.
set_dispatch
<
FloatImm
>
(
NoOp
)
...
...
src/pass/schedule_ops.cc
View file @
28876530
...
...
@@ -13,11 +13,19 @@ namespace {
// inject the operator's realization on the stmt.
class
InjectRealize
:
public
IRMutator
{
public
:
explicit
InjectRealize
(
std
::
vector
<
Tensor
>
tensors
)
:
tensors_
(
tensors
)
{}
std
::
vector
<
Tensor
>
tensors_
;
};
explicit
InjectRealize
(
Schedule
sch
)
:
sch_
(
sch
)
{}
Stmt
Mutate
(
Stmt
stmt
)
final
{
stmt
=
IRMutator
::
Mutate
(
stmt
);
const
For
*
op
=
stmt
.
as
<
For
>
();
return
stmt
;
}
private
:
// the operations to be carried
Schedule
sch_
;
};
}
// namespace
}
// namespace ir
...
...
tests/python/test_basic.py
View file @
28876530
...
...
@@ -22,10 +22,15 @@ def test_let():
x
=
tvm
.
Var
(
'x'
)
y
=
tvm
.
Var
(
'y'
)
stmt
=
tvm
.
make
.
LetStmt
(
x
,
10
,
tvm
.
make
.
Evaluate
(
x
+
1
),
y
,
"stride"
)
assert
stmt
.
attr_of_node
==
y
print
(
stmt
)
x
,
10
,
tvm
.
make
.
Evaluate
(
x
+
1
));
def
test_attr
():
x
=
tvm
.
Var
(
'x'
)
y
=
tvm
.
Var
(
'y'
)
stmt
=
tvm
.
make
.
AttrStmt
(
y
,
"stride"
,
10
,
tvm
.
make
.
Evaluate
(
x
+
1
));
assert
stmt
.
node
==
y
print
(
stmt
)
def
test_basic
():
a
=
tvm
.
Var
(
'a'
)
...
...
@@ -44,6 +49,8 @@ def test_stmt():
if
__name__
==
"__main__"
:
test_attr
()
test_const
()
test_make
()
test_ir
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment